diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactId.java b/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactId.java index 9e11e555783018e242eef89a7141e8565f68ea12..21293599febac3854411de04de5ace41017165e4 100644 --- a/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactId.java +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactId.java @@ -6,7 +6,7 @@ import javax.annotation.concurrent.Immutable; /** * Type-safe wrapper for an integer that uniquely identifies a contact within - * the scope of a single node. + * the scope of the local device. */ @Immutable @NotNullByDefault diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactManager.java b/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactManager.java index e6be9a4d2f61956d238ab6e147761e766b307a33..dd3db0503873264def7f6f8b876c6c06a6497c77 100644 --- a/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactManager.java +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactManager.java @@ -23,13 +23,21 @@ public interface ContactManager { void registerRemoveContactHook(RemoveContactHook hook); /** - * Stores a contact within the given transaction associated with the given - * local and remote pseudonyms, and returns an ID for the contact. + * Stores a contact associated with the given local and remote pseudonyms, + * derives and stores transport keys for each transport, and returns an ID + * for the contact. */ ContactId addContact(Transaction txn, Author remote, AuthorId local, SecretKey master, long timestamp, boolean alice, boolean verified, boolean active) throws DbException; + /** + * Stores a contact associated with the given local and remote pseudonyms + * and returns an ID for the contact. + */ + ContactId addContact(Transaction txn, Author remote, AuthorId local, + boolean verified, boolean active) throws DbException; + /** * Stores a contact associated with the given local and remote pseudonyms, * and returns an ID for the contact. diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/crypto/TransportCrypto.java b/bramble-api/src/main/java/org/briarproject/bramble/api/crypto/TransportCrypto.java index 6385d1f015bbf7d0dff30d0c03390caf7af11196..cbd6449b48f3f0c9a64927ee1a7a7169899eef28 100644 --- a/bramble-api/src/main/java/org/briarproject/bramble/api/crypto/TransportCrypto.java +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/crypto/TransportCrypto.java @@ -14,9 +14,10 @@ public interface TransportCrypto { * rotation period from the given master secret. * * @param alice whether the keys are for use by Alice or Bob. + * @param active whether the keys are usable for outgoing streams. */ TransportKeys deriveTransportKeys(TransportId t, SecretKey master, - long rotationPeriod, boolean alice); + long rotationPeriod, boolean alice, boolean active); /** * Rotates the given transport keys to the given rotation period. If the diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/db/DatabaseComponent.java b/bramble-api/src/main/java/org/briarproject/bramble/api/db/DatabaseComponent.java index 08fbe45405b0b5c3a1011e4780156a7d157c6219..ea05938d49c5a24fe0d0da14bd40d718e2605824 100644 --- a/bramble-api/src/main/java/org/briarproject/bramble/api/db/DatabaseComponent.java +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/db/DatabaseComponent.java @@ -18,6 +18,8 @@ import org.briarproject.bramble.api.sync.MessageId; import org.briarproject.bramble.api.sync.MessageStatus; import org.briarproject.bramble.api.sync.Offer; import org.briarproject.bramble.api.sync.Request; +import org.briarproject.bramble.api.transport.KeySet; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.TransportKeys; import java.util.Collection; @@ -102,10 +104,17 @@ public interface DatabaseComponent { throws DbException; /** - * Stores transport keys for a newly added contact. + * Stores the given transport keys, optionally binding them to the given + * contact, and returns a key set ID. */ - void addTransportKeys(Transaction txn, ContactId c, TransportKeys k) - throws DbException; + KeySetId addTransportKeys(Transaction txn, @Nullable ContactId c, + TransportKeys k) throws DbException; + + /** + * Binds the given keys for the given transport to the given contact. + */ + void bindTransportKeys(Transaction txn, ContactId c, TransportId t, + KeySetId k) throws DbException; /** * Returns true if the database contains the given contact for the given @@ -394,15 +403,14 @@ public interface DatabaseComponent { * <p/> * Read-only. */ - Map<ContactId, TransportKeys> getTransportKeys(Transaction txn, - TransportId t) throws DbException; + Collection<KeySet> getTransportKeys(Transaction txn, TransportId t) + throws DbException; /** - * Increments the outgoing stream counter for the given contact and - * transport in the given rotation period . + * Increments the outgoing stream counter for the given transport keys. */ - void incrementStreamCounter(Transaction txn, ContactId c, TransportId t, - long rotationPeriod) throws DbException; + void incrementStreamCounter(Transaction txn, TransportId t, KeySetId k) + throws DbException; /** * Merges the given metadata with the existing metadata for the given @@ -472,6 +480,12 @@ public interface DatabaseComponent { */ void removeTransport(Transaction txn, TransportId t) throws DbException; + /** + * Removes the given transport keys from the database. + */ + void removeTransportKeys(Transaction txn, TransportId t, KeySetId k) + throws DbException; + /** * Marks the given contact as verified. */ @@ -507,15 +521,21 @@ public interface DatabaseComponent { Collection<MessageId> dependencies) throws DbException; /** - * Sets the reordering window for the given contact and transport in the + * Sets the reordering window for the given key set and transport in the * given rotation period. */ - void setReorderingWindow(Transaction txn, ContactId c, TransportId t, + void setReorderingWindow(Transaction txn, KeySetId k, TransportId t, long rotationPeriod, long base, byte[] bitmap) throws DbException; + /** + * Marks the given transport keys as usable for outgoing streams. + */ + void setTransportKeysActive(Transaction txn, TransportId t, KeySetId k) + throws DbException; + /** * Stores the given transport keys, deleting any keys they have replaced. */ - void updateTransportKeys(Transaction txn, - Map<ContactId, TransportKeys> keys) throws DbException; + void updateTransportKeys(Transaction txn, Collection<KeySet> keys) + throws DbException; } diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeyManager.java b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeyManager.java index 97afdc133e1a672ecfcd5bda9959aa64ca619e2f..065e5d2bc3edb66d7f2ae63cee0602dfbb304f30 100644 --- a/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeyManager.java +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeyManager.java @@ -6,6 +6,8 @@ import org.briarproject.bramble.api.db.DbException; import org.briarproject.bramble.api.db.Transaction; import org.briarproject.bramble.api.plugin.TransportId; +import java.util.Map; + import javax.annotation.Nullable; /** @@ -16,13 +18,51 @@ public interface KeyManager { /** * Informs the key manager that a new contact has been added. Derives and - * stores transport keys for communicating with the contact. + * stores a set of transport keys for communicating with the contact over + * each transport. + * <p/> * {@link StreamContext StreamContexts} for the contact can be created * after this method has returned. */ void addContact(Transaction txn, ContactId c, SecretKey master, long timestamp, boolean alice) throws DbException; + /** + * Derives and stores a set of unbound transport keys for each transport + * and returns the key set IDs. + * <p/> + * The keys must be bound before they can be used for incoming streams, + * and also activated before they can be used for outgoing streams. + */ + Map<TransportId, KeySetId> addUnboundKeys(Transaction txn, SecretKey master, + long timestamp, boolean alice) throws DbException; + + /** + * Binds the given transport keys to the given contact. + */ + void bindKeys(Transaction txn, ContactId c, Map<TransportId, KeySetId> keys) + throws DbException; + + /** + * Marks the given transport keys as usable for outgoing streams. Keys must + * be bound before they are activated. + */ + void activateKeys(Transaction txn, Map<TransportId, KeySetId> keys) + throws DbException; + + /** + * Removes the given transport keys, which must not have been bound, from + * the manager and the database. + */ + void removeKeys(Transaction txn, Map<TransportId, KeySetId> keys) + throws DbException; + + /** + * Returns true if we have keys that can be used for outgoing streams to + * the given contact over the given transport. + */ + boolean canSendOutgoingStreams(ContactId c, TransportId t); + /** * Returns a {@link StreamContext} for sending a stream to the given * contact over the given transport, or null if an error occurs or the diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySet.java b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySet.java new file mode 100644 index 0000000000000000000000000000000000000000..9cc8f63c294a6cd0b4674944e51307f72058ddf6 --- /dev/null +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySet.java @@ -0,0 +1,51 @@ +package org.briarproject.bramble.api.transport; + +import org.briarproject.bramble.api.contact.ContactId; +import org.briarproject.bramble.api.nullsafety.NotNullByDefault; + +import javax.annotation.Nullable; +import javax.annotation.concurrent.Immutable; + +/** + * A set of transport keys for communicating with a contact. If the keys have + * not yet been bound to a contact, {@link #getContactId()}} returns null. + */ +@Immutable +@NotNullByDefault +public class KeySet { + + private final KeySetId keySetId; + @Nullable + private final ContactId contactId; + private final TransportKeys transportKeys; + + public KeySet(KeySetId keySetId, @Nullable ContactId contactId, + TransportKeys transportKeys) { + this.keySetId = keySetId; + this.contactId = contactId; + this.transportKeys = transportKeys; + } + + public KeySetId getKeySetId() { + return keySetId; + } + + @Nullable + public ContactId getContactId() { + return contactId; + } + + public TransportKeys getTransportKeys() { + return transportKeys; + } + + @Override + public int hashCode() { + return keySetId.hashCode(); + } + + @Override + public boolean equals(Object o) { + return o instanceof KeySet && keySetId.equals(((KeySet) o).keySetId); + } +} diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySetId.java b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySetId.java new file mode 100644 index 0000000000000000000000000000000000000000..1f872e72a34d7ffb480423d11bbe1aa0db1bdf82 --- /dev/null +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySetId.java @@ -0,0 +1,36 @@ +package org.briarproject.bramble.api.transport; + +import org.briarproject.bramble.api.nullsafety.NotNullByDefault; + +import javax.annotation.concurrent.Immutable; + +/** + * Type-safe wrapper for an integer that uniquely identifies a set of transport + * keys within the scope of the local device. + * <p/> + * Key sets created on a given device must have increasing identifiers. + */ +@Immutable +@NotNullByDefault +public class KeySetId { + + private final int id; + + public KeySetId(int id) { + this.id = id; + } + + public int getInt() { + return id; + } + + @Override + public int hashCode() { + return id; + } + + @Override + public boolean equals(Object o) { + return o instanceof KeySetId && id == ((KeySetId) o).id; + } +} diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/transport/OutgoingKeys.java b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/OutgoingKeys.java index 202c46e6a0ec7abd40f4aa7d381f10cca6590b53..4214ffd9146d869ca04378f0b126ff7875b24f78 100644 --- a/bramble-api/src/main/java/org/briarproject/bramble/api/transport/OutgoingKeys.java +++ b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/OutgoingKeys.java @@ -10,18 +10,20 @@ public class OutgoingKeys { private final SecretKey tagKey, headerKey; private final long rotationPeriod, streamCounter; + private final boolean active; public OutgoingKeys(SecretKey tagKey, SecretKey headerKey, - long rotationPeriod) { - this(tagKey, headerKey, rotationPeriod, 0); + long rotationPeriod, boolean active) { + this(tagKey, headerKey, rotationPeriod, 0, active); } public OutgoingKeys(SecretKey tagKey, SecretKey headerKey, - long rotationPeriod, long streamCounter) { + long rotationPeriod, long streamCounter, boolean active) { this.tagKey = tagKey; this.headerKey = headerKey; this.rotationPeriod = rotationPeriod; this.streamCounter = streamCounter; + this.active = active; } public SecretKey getTagKey() { @@ -39,4 +41,8 @@ public class OutgoingKeys { public long getStreamCounter() { return streamCounter; } + + public boolean isActive() { + return active; + } } \ No newline at end of file diff --git a/bramble-core/src/main/java/org/briarproject/bramble/contact/ContactManagerImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/contact/ContactManagerImpl.java index 25e0681c963778bf649bef07e801a2a06474314a..025a6d38d079a873080ac5cd2acbd20ca2a013ae 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/contact/ContactManagerImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/contact/ContactManagerImpl.java @@ -50,7 +50,7 @@ class ContactManagerImpl implements ContactManager { @Override public ContactId addContact(Transaction txn, Author remote, AuthorId local, - SecretKey master,long timestamp, boolean alice, boolean verified, + SecretKey master, long timestamp, boolean alice, boolean verified, boolean active) throws DbException { ContactId c = db.addContact(txn, remote, local, verified, active); keyManager.addContact(txn, c, master, timestamp, alice); @@ -60,6 +60,16 @@ class ContactManagerImpl implements ContactManager { return c; } + @Override + public ContactId addContact(Transaction txn, Author remote, AuthorId local, + boolean verified, boolean active) throws DbException { + ContactId c = db.addContact(txn, remote, local, verified, active); + Contact contact = db.getContact(txn, c); + for (AddContactHook hook : addHooks) + hook.addingContact(txn, contact); + return c; + } + @Override public ContactId addContact(Author remote, AuthorId local, SecretKey master, long timestamp, boolean alice, boolean verified, boolean active) diff --git a/bramble-core/src/main/java/org/briarproject/bramble/crypto/TransportCryptoImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/crypto/TransportCryptoImpl.java index db35c9d5e3c99a126e4b8e623705b27402c7b6d9..2d4ffb7d31ea719854ee2a52c0157f8a032fff84 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/crypto/TransportCryptoImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/crypto/TransportCryptoImpl.java @@ -36,7 +36,8 @@ class TransportCryptoImpl implements TransportCrypto { @Override public TransportKeys deriveTransportKeys(TransportId t, - SecretKey master, long rotationPeriod, boolean alice) { + SecretKey master, long rotationPeriod, boolean alice, + boolean active) { // Keys for the previous period are derived from the master secret SecretKey inTagPrev = deriveTagKey(master, t, !alice); SecretKey inHeaderPrev = deriveHeaderKey(master, t, !alice); @@ -57,7 +58,7 @@ class TransportCryptoImpl implements TransportCrypto { IncomingKeys inNext = new IncomingKeys(inTagNext, inHeaderNext, rotationPeriod + 1); OutgoingKeys outCurr = new OutgoingKeys(outTagCurr, outHeaderCurr, - rotationPeriod); + rotationPeriod, active); // Collect and return the keys return new TransportKeys(t, inPrev, inCurr, inNext, outCurr); } @@ -71,6 +72,7 @@ class TransportCryptoImpl implements TransportCrypto { IncomingKeys inNext = k.getNextIncomingKeys(); OutgoingKeys outCurr = k.getCurrentOutgoingKeys(); long startPeriod = outCurr.getRotationPeriod(); + boolean active = outCurr.isActive(); // Rotate the keys for (long p = startPeriod + 1; p <= rotationPeriod; p++) { inPrev = inCurr; @@ -80,7 +82,7 @@ class TransportCryptoImpl implements TransportCrypto { inNext = new IncomingKeys(inNextTag, inNextHeader, p + 1); SecretKey outCurrTag = rotateKey(outCurr.getTagKey(), p); SecretKey outCurrHeader = rotateKey(outCurr.getHeaderKey(), p); - outCurr = new OutgoingKeys(outCurrTag, outCurrHeader, p); + outCurr = new OutgoingKeys(outCurrTag, outCurrHeader, p, active); } // Collect and return the keys return new TransportKeys(k.getTransportId(), inPrev, inCurr, inNext, diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java b/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java index bc3166d5622e605a1fdddcbf0858bc438b2216ce..dcf96ea4e116773c97ca99d7ff27b8d474ecd53a 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java @@ -21,6 +21,8 @@ import org.briarproject.bramble.api.sync.Message; import org.briarproject.bramble.api.sync.MessageId; import org.briarproject.bramble.api.sync.MessageStatus; import org.briarproject.bramble.api.sync.ValidationManager.State; +import org.briarproject.bramble.api.transport.KeySet; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.TransportKeys; import java.util.Collection; @@ -123,9 +125,16 @@ interface Database<T> { throws DbException; /** - * Stores transport keys for a newly added contact. + * Stores the given transport keys, optionally binding them to the given + * contact, and returns a key set ID. */ - void addTransportKeys(T txn, ContactId c, TransportKeys k) + KeySetId addTransportKeys(T txn, @Nullable ContactId c, TransportKeys k) + throws DbException; + + /** + * Binds the given keys for the given transport to the given contact. + */ + void bindTransportKeys(T txn, ContactId c, TransportId t, KeySetId k) throws DbException; /** @@ -486,15 +495,14 @@ interface Database<T> { * <p/> * Read-only. */ - Map<ContactId, TransportKeys> getTransportKeys(T txn, TransportId t) + Collection<KeySet> getTransportKeys(T txn, TransportId t) throws DbException; /** - * Increments the outgoing stream counter for the given contact and - * transport in the given rotation period. + * Increments the outgoing stream counter for the given transport keys. */ - void incrementStreamCounter(T txn, ContactId c, TransportId t, - long rotationPeriod) throws DbException; + void incrementStreamCounter(T txn, TransportId t, KeySetId k) + throws DbException; /** * Marks the given messages as not needing to be acknowledged to the @@ -584,6 +592,12 @@ interface Database<T> { */ void removeTransport(T txn, TransportId t) throws DbException; + /** + * Removes the given transport keys from the database. + */ + void removeTransportKeys(T txn, TransportId t, KeySetId k) + throws DbException; + /** * Resets the transmission count and expiry time of the given message with * respect to the given contact. @@ -619,12 +633,18 @@ interface Database<T> { void setMessageState(T txn, MessageId m, State state) throws DbException; /** - * Sets the reordering window for the given contact and transport in the + * Sets the reordering window for the given key set and transport in the * given rotation period. */ - void setReorderingWindow(T txn, ContactId c, TransportId t, + void setReorderingWindow(T txn, KeySetId k, TransportId t, long rotationPeriod, long base, byte[] bitmap) throws DbException; + /** + * Marks the given transport keys as usable for outgoing streams. + */ + void setTransportKeysActive(T txn, TransportId t, KeySetId k) + throws DbException; + /** * Updates the transmission count and expiry time of the given message * with respect to the given contact, using the latency of the transport @@ -636,6 +656,5 @@ interface Database<T> { /** * Stores the given transport keys, deleting any keys they have replaced. */ - void updateTransportKeys(T txn, Map<ContactId, TransportKeys> keys) - throws DbException; + void updateTransportKeys(T txn, Collection<KeySet> keys) throws DbException; } diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java index 013233f395ceac5ff0f65cf5a0044473238858c4..00f0bb62292c029505e98fd07b4496f81b2a6d42 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java @@ -51,15 +51,15 @@ import org.briarproject.bramble.api.sync.event.MessageToAckEvent; import org.briarproject.bramble.api.sync.event.MessageToRequestEvent; import org.briarproject.bramble.api.sync.event.MessagesAckedEvent; import org.briarproject.bramble.api.sync.event.MessagesSentEvent; +import org.briarproject.bramble.api.transport.KeySet; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.TransportKeys; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.logging.Logger; @@ -234,15 +234,27 @@ class DatabaseComponentImpl<T> implements DatabaseComponent { } @Override - public void addTransportKeys(Transaction transaction, ContactId c, - TransportKeys k) throws DbException { + public KeySetId addTransportKeys(Transaction transaction, + @Nullable ContactId c, TransportKeys k) throws DbException { if (transaction.isReadOnly()) throw new IllegalArgumentException(); T txn = unbox(transaction); - if (!db.containsContact(txn, c)) + if (c != null && !db.containsContact(txn, c)) throw new NoSuchContactException(); if (!db.containsTransport(txn, k.getTransportId())) throw new NoSuchTransportException(); - db.addTransportKeys(txn, c, k); + return db.addTransportKeys(txn, c, k); + } + + @Override + public void bindTransportKeys(Transaction transaction, ContactId c, + TransportId t, KeySetId k) throws DbException { + if (transaction.isReadOnly()) throw new IllegalArgumentException(); + T txn = unbox(transaction); + if (!db.containsContact(txn, c)) + throw new NoSuchContactException(); + if (!db.containsTransport(txn, t)) + throw new NoSuchTransportException(); + db.bindTransportKeys(txn, c, t, k); } @Override @@ -586,8 +598,8 @@ class DatabaseComponentImpl<T> implements DatabaseComponent { } @Override - public Map<ContactId, TransportKeys> getTransportKeys( - Transaction transaction, TransportId t) throws DbException { + public Collection<KeySet> getTransportKeys(Transaction transaction, + TransportId t) throws DbException { T txn = unbox(transaction); if (!db.containsTransport(txn, t)) throw new NoSuchTransportException(); @@ -595,15 +607,13 @@ class DatabaseComponentImpl<T> implements DatabaseComponent { } @Override - public void incrementStreamCounter(Transaction transaction, ContactId c, - TransportId t, long rotationPeriod) throws DbException { + public void incrementStreamCounter(Transaction transaction, TransportId t, + KeySetId k) throws DbException { if (transaction.isReadOnly()) throw new IllegalArgumentException(); T txn = unbox(transaction); - if (!db.containsContact(txn, c)) - throw new NoSuchContactException(); if (!db.containsTransport(txn, t)) throw new NoSuchTransportException(); - db.incrementStreamCounter(txn, c, t, rotationPeriod); + db.incrementStreamCounter(txn, t, k); } @Override @@ -779,6 +789,16 @@ class DatabaseComponentImpl<T> implements DatabaseComponent { db.removeTransport(txn, t); } + @Override + public void removeTransportKeys(Transaction transaction, + TransportId t, KeySetId k) throws DbException { + if (transaction.isReadOnly()) throw new IllegalArgumentException(); + T txn = unbox(transaction); + if (!db.containsTransport(txn, t)) + throw new NoSuchTransportException(); + db.removeTransportKeys(txn, t, k); + } + @Override public void setContactVerified(Transaction transaction, ContactId c) throws DbException { @@ -858,31 +878,35 @@ class DatabaseComponentImpl<T> implements DatabaseComponent { } @Override - public void setReorderingWindow(Transaction transaction, ContactId c, + public void setReorderingWindow(Transaction transaction, KeySetId k, TransportId t, long rotationPeriod, long base, byte[] bitmap) throws DbException { if (transaction.isReadOnly()) throw new IllegalArgumentException(); T txn = unbox(transaction); - if (!db.containsContact(txn, c)) - throw new NoSuchContactException(); if (!db.containsTransport(txn, t)) throw new NoSuchTransportException(); - db.setReorderingWindow(txn, c, t, rotationPeriod, base, bitmap); + db.setReorderingWindow(txn, k, t, rotationPeriod, base, bitmap); + } + + @Override + public void setTransportKeysActive(Transaction transaction, TransportId t, + KeySetId k) throws DbException { + if (transaction.isReadOnly()) throw new IllegalArgumentException(); + T txn = unbox(transaction); + if (!db.containsTransport(txn, t)) + throw new NoSuchTransportException(); + db.setTransportKeysActive(txn, t, k); } @Override public void updateTransportKeys(Transaction transaction, - Map<ContactId, TransportKeys> keys) throws DbException { + Collection<KeySet> keys) throws DbException { if (transaction.isReadOnly()) throw new IllegalArgumentException(); T txn = unbox(transaction); - Map<ContactId, TransportKeys> filtered = new HashMap<>(); - for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { - ContactId c = e.getKey(); - TransportKeys k = e.getValue(); - if (db.containsContact(txn, c) - && db.containsTransport(txn, k.getTransportId())) { - filtered.put(c, k); - } + Collection<KeySet> filtered = new ArrayList<>(); + for (KeySet ks : keys) { + TransportId t = ks.getTransportKeys().getTransportId(); + if (db.containsTransport(txn, t)) filtered.add(ks); } db.updateTransportKeys(txn, filtered); } diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java b/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java index 51fa9ae309c39b93421e37673745ce143e8e59f5..dd5874f6edcca3400a2922d3f1feb9e2db5f2b2f 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java @@ -25,6 +25,8 @@ import org.briarproject.bramble.api.sync.MessageStatus; import org.briarproject.bramble.api.sync.ValidationManager.State; import org.briarproject.bramble.api.system.Clock; import org.briarproject.bramble.api.transport.IncomingKeys; +import org.briarproject.bramble.api.transport.KeySet; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.TransportKeys; @@ -223,37 +225,44 @@ abstract class JdbcDatabase implements Database<Connection> { + " maxLatency INT NOT NULL," + " PRIMARY KEY (transportId))"; - private static final String CREATE_INCOMING_KEYS = - "CREATE TABLE incomingKeys" - + " (contactId INT NOT NULL," - + " transportId _STRING NOT NULL," + private static final String CREATE_OUTGOING_KEYS = + "CREATE TABLE outgoingKeys" + + " (transportId _STRING NOT NULL," + + " keySetId _COUNTER," + " rotationPeriod BIGINT NOT NULL," + + " contactId INT," // Null if keys are not bound + " tagKey _SECRET NOT NULL," + " headerKey _SECRET NOT NULL," - + " base BIGINT NOT NULL," - + " bitmap _BINARY NOT NULL," - + " PRIMARY KEY (contactId, transportId, rotationPeriod)," - + " FOREIGN KEY (contactId)" - + " REFERENCES contacts (contactId)" - + " ON DELETE CASCADE," + + " stream BIGINT NOT NULL," + + " active BOOLEAN NOT NULL," + + " PRIMARY KEY (transportId, keySetId)," + " FOREIGN KEY (transportId)" + " REFERENCES transports (transportId)" + + " ON DELETE CASCADE," + + " UNIQUE (keySetId)," + + " FOREIGN KEY (contactId)" + + " REFERENCES contacts (contactId)" + " ON DELETE CASCADE)"; - private static final String CREATE_OUTGOING_KEYS = - "CREATE TABLE outgoingKeys" - + " (contactId INT NOT NULL," - + " transportId _STRING NOT NULL," + private static final String CREATE_INCOMING_KEYS = + "CREATE TABLE incomingKeys" + + " (transportId _STRING NOT NULL," + + " keySetId INT NOT NULL," + " rotationPeriod BIGINT NOT NULL," + + " contactId INT," // Null if keys are not bound + " tagKey _SECRET NOT NULL," + " headerKey _SECRET NOT NULL," - + " stream BIGINT NOT NULL," - + " PRIMARY KEY (contactId, transportId)," - + " FOREIGN KEY (contactId)" - + " REFERENCES contacts (contactId)" - + " ON DELETE CASCADE," + + " base BIGINT NOT NULL," + + " bitmap _BINARY NOT NULL," + + " PRIMARY KEY (transportId, keySetId, rotationPeriod)," + " FOREIGN KEY (transportId)" + " REFERENCES transports (transportId)" + + " ON DELETE CASCADE," + + " FOREIGN KEY (keySetId)" + + " REFERENCES outgoingKeys (keySetId)" + + " ON DELETE CASCADE," + + " FOREIGN KEY (contactId)" + + " REFERENCES contacts (contactId)" + " ON DELETE CASCADE)"; private static final String INDEX_CONTACTS_BY_AUTHOR_ID = @@ -415,8 +424,8 @@ abstract class JdbcDatabase implements Database<Connection> { s.executeUpdate(insertTypeNames(CREATE_OFFERS)); s.executeUpdate(insertTypeNames(CREATE_STATUSES)); s.executeUpdate(insertTypeNames(CREATE_TRANSPORTS)); - s.executeUpdate(insertTypeNames(CREATE_INCOMING_KEYS)); s.executeUpdate(insertTypeNames(CREATE_OUTGOING_KEYS)); + s.executeUpdate(insertTypeNames(CREATE_INCOMING_KEYS)); s.close(); } catch (SQLException e) { tryToClose(s); @@ -865,60 +874,104 @@ abstract class JdbcDatabase implements Database<Connection> { } @Override - public void addTransportKeys(Connection txn, ContactId c, TransportKeys k) - throws DbException { + public KeySetId addTransportKeys(Connection txn, @Nullable ContactId c, + TransportKeys k) throws DbException { PreparedStatement ps = null; + ResultSet rs = null; try { - // Store the incoming keys - String sql = "INSERT INTO incomingKeys (contactId, transportId," - + " rotationPeriod, tagKey, headerKey, base, bitmap)" + // Store the outgoing keys + String sql = "INSERT INTO outgoingKeys (contactId, transportId," + + " rotationPeriod, tagKey, headerKey, stream, active)" + " VALUES (?, ?, ?, ?, ?, ?, ?)"; ps = txn.prepareStatement(sql); - ps.setInt(1, c.getInt()); + if (c == null) ps.setNull(1, INTEGER); + else ps.setInt(1, c.getInt()); ps.setString(2, k.getTransportId().getString()); + OutgoingKeys outCurr = k.getCurrentOutgoingKeys(); + ps.setLong(3, outCurr.getRotationPeriod()); + ps.setBytes(4, outCurr.getTagKey().getBytes()); + ps.setBytes(5, outCurr.getHeaderKey().getBytes()); + ps.setLong(6, outCurr.getStreamCounter()); + ps.setBoolean(7, outCurr.isActive()); + int affected = ps.executeUpdate(); + if (affected != 1) throw new DbStateException(); + ps.close(); + // Get the new (highest) key set ID + sql = "SELECT keySetId FROM outgoingKeys" + + " ORDER BY keySetId DESC LIMIT 1"; + ps = txn.prepareStatement(sql); + rs = ps.executeQuery(); + if (!rs.next()) throw new DbStateException(); + KeySetId keySetId = new KeySetId(rs.getInt(1)); + if (rs.next()) throw new DbStateException(); + rs.close(); + ps.close(); + // Store the incoming keys + sql = "INSERT INTO incomingKeys (keySetId, contactId, transportId," + + " rotationPeriod, tagKey, headerKey, base, bitmap)" + + " VALUES (?, ?, ?, ?, ?, ?, ?, ?)"; + ps = txn.prepareStatement(sql); + ps.setInt(1, keySetId.getInt()); + if (c == null) ps.setNull(2, INTEGER); + else ps.setInt(2, c.getInt()); + ps.setString(3, k.getTransportId().getString()); // Previous rotation period IncomingKeys inPrev = k.getPreviousIncomingKeys(); - ps.setLong(3, inPrev.getRotationPeriod()); - ps.setBytes(4, inPrev.getTagKey().getBytes()); - ps.setBytes(5, inPrev.getHeaderKey().getBytes()); - ps.setLong(6, inPrev.getWindowBase()); - ps.setBytes(7, inPrev.getWindowBitmap()); + ps.setLong(4, inPrev.getRotationPeriod()); + ps.setBytes(5, inPrev.getTagKey().getBytes()); + ps.setBytes(6, inPrev.getHeaderKey().getBytes()); + ps.setLong(7, inPrev.getWindowBase()); + ps.setBytes(8, inPrev.getWindowBitmap()); ps.addBatch(); // Current rotation period IncomingKeys inCurr = k.getCurrentIncomingKeys(); - ps.setLong(3, inCurr.getRotationPeriod()); - ps.setBytes(4, inCurr.getTagKey().getBytes()); - ps.setBytes(5, inCurr.getHeaderKey().getBytes()); - ps.setLong(6, inCurr.getWindowBase()); - ps.setBytes(7, inCurr.getWindowBitmap()); + ps.setLong(4, inCurr.getRotationPeriod()); + ps.setBytes(5, inCurr.getTagKey().getBytes()); + ps.setBytes(6, inCurr.getHeaderKey().getBytes()); + ps.setLong(7, inCurr.getWindowBase()); + ps.setBytes(8, inCurr.getWindowBitmap()); ps.addBatch(); // Next rotation period IncomingKeys inNext = k.getNextIncomingKeys(); - ps.setLong(3, inNext.getRotationPeriod()); - ps.setBytes(4, inNext.getTagKey().getBytes()); - ps.setBytes(5, inNext.getHeaderKey().getBytes()); - ps.setLong(6, inNext.getWindowBase()); - ps.setBytes(7, inNext.getWindowBitmap()); + ps.setLong(4, inNext.getRotationPeriod()); + ps.setBytes(5, inNext.getTagKey().getBytes()); + ps.setBytes(6, inNext.getHeaderKey().getBytes()); + ps.setLong(7, inNext.getWindowBase()); + ps.setBytes(8, inNext.getWindowBitmap()); ps.addBatch(); int[] batchAffected = ps.executeBatch(); if (batchAffected.length != 3) throw new DbStateException(); for (int rows : batchAffected) if (rows != 1) throw new DbStateException(); ps.close(); - // Store the outgoing keys - sql = "INSERT INTO outgoingKeys (contactId, transportId," - + " rotationPeriod, tagKey, headerKey, stream)" - + " VALUES (?, ?, ?, ?, ?, ?)"; + return keySetId; + } catch (SQLException e) { + tryToClose(rs); + tryToClose(ps); + throw new DbException(e); + } + } + + @Override + public void bindTransportKeys(Connection txn, ContactId c, TransportId t, + KeySetId k) throws DbException { + PreparedStatement ps = null; + try { + String sql = "UPDATE outgoingKeys SET contactId = ?" + + " WHERE keySetId = ?"; ps = txn.prepareStatement(sql); ps.setInt(1, c.getInt()); - ps.setString(2, k.getTransportId().getString()); - OutgoingKeys outCurr = k.getCurrentOutgoingKeys(); - ps.setLong(3, outCurr.getRotationPeriod()); - ps.setBytes(4, outCurr.getTagKey().getBytes()); - ps.setBytes(5, outCurr.getHeaderKey().getBytes()); - ps.setLong(6, outCurr.getStreamCounter()); + ps.setInt(2, k.getInt()); int affected = ps.executeUpdate(); - if (affected != 1) throw new DbStateException(); + if (affected < 0) throw new DbStateException(); + ps.close(); + sql = "UPDATE incomingKeys SET contactId = ?" + + " WHERE keySetId = ?"; + ps = txn.prepareStatement(sql); + ps.setInt(1, c.getInt()); + ps.setInt(2, k.getInt()); + affected = ps.executeUpdate(); + if (affected < 0) throw new DbStateException(); ps.close(); } catch (SQLException e) { tryToClose(ps); @@ -2078,8 +2131,8 @@ abstract class JdbcDatabase implements Database<Connection> { } @Override - public Map<ContactId, TransportKeys> getTransportKeys(Connection txn, - TransportId t) throws DbException { + public Collection<KeySet> getTransportKeys(Connection txn, TransportId t) + throws DbException { PreparedStatement ps = null; ResultSet rs = null; try { @@ -2088,7 +2141,7 @@ abstract class JdbcDatabase implements Database<Connection> { + " base, bitmap" + " FROM incomingKeys" + " WHERE transportId = ?" - + " ORDER BY contactId, rotationPeriod"; + + " ORDER BY keySetId, rotationPeriod"; ps = txn.prepareStatement(sql); ps.setString(1, t.getString()); rs = ps.executeQuery(); @@ -2105,29 +2158,34 @@ abstract class JdbcDatabase implements Database<Connection> { rs.close(); ps.close(); // Retrieve the outgoing keys in the same order - sql = "SELECT contactId, rotationPeriod, tagKey, headerKey, stream" + sql = "SELECT keySetId, contactId, rotationPeriod," + + " tagKey, headerKey, stream, active" + " FROM outgoingKeys" + " WHERE transportId = ?" - + " ORDER BY contactId, rotationPeriod"; + + " ORDER BY keySetId"; ps = txn.prepareStatement(sql); ps.setString(1, t.getString()); rs = ps.executeQuery(); - Map<ContactId, TransportKeys> keys = new HashMap<>(); + Collection<KeySet> keys = new ArrayList<>(); for (int i = 0; rs.next(); i++) { // There should be three times as many incoming keys if (inKeys.size() < (i + 1) * 3) throw new DbStateException(); - ContactId contactId = new ContactId(rs.getInt(1)); - long rotationPeriod = rs.getLong(2); - SecretKey tagKey = new SecretKey(rs.getBytes(3)); - SecretKey headerKey = new SecretKey(rs.getBytes(4)); - long streamCounter = rs.getLong(5); + KeySetId keySetId = new KeySetId(rs.getInt(1)); + ContactId contactId = new ContactId(rs.getInt(2)); + if (rs.wasNull()) contactId = null; + long rotationPeriod = rs.getLong(3); + SecretKey tagKey = new SecretKey(rs.getBytes(4)); + SecretKey headerKey = new SecretKey(rs.getBytes(5)); + long streamCounter = rs.getLong(6); + boolean active = rs.getBoolean(7); OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey, - rotationPeriod, streamCounter); + rotationPeriod, streamCounter, active); IncomingKeys inPrev = inKeys.get(i * 3); IncomingKeys inCurr = inKeys.get(i * 3 + 1); IncomingKeys inNext = inKeys.get(i * 3 + 2); - keys.put(contactId, new TransportKeys(t, inPrev, inCurr, - inNext, outCurr)); + TransportKeys transportKeys = new TransportKeys(t, inPrev, + inCurr, inNext, outCurr); + keys.add(new KeySet(keySetId, contactId, transportKeys)); } rs.close(); ps.close(); @@ -2140,17 +2198,15 @@ abstract class JdbcDatabase implements Database<Connection> { } @Override - public void incrementStreamCounter(Connection txn, ContactId c, - TransportId t, long rotationPeriod) throws DbException { + public void incrementStreamCounter(Connection txn, TransportId t, + KeySetId k) throws DbException { PreparedStatement ps = null; try { String sql = "UPDATE outgoingKeys SET stream = stream + 1" - + " WHERE contactId = ? AND transportId = ?" - + " AND rotationPeriod = ?"; + + " WHERE transportId = ? AND keySetId = ?"; ps = txn.prepareStatement(sql); - ps.setInt(1, c.getInt()); - ps.setString(2, t.getString()); - ps.setLong(3, rotationPeriod); + ps.setString(1, t.getString()); + ps.setInt(2, k.getInt()); int affected = ps.executeUpdate(); if (affected != 1) throw new DbStateException(); ps.close(); @@ -2626,6 +2682,27 @@ abstract class JdbcDatabase implements Database<Connection> { } } + @Override + public void removeTransportKeys(Connection txn, TransportId t, KeySetId k) + throws DbException { + PreparedStatement ps = null; + try { + // Delete any existing outgoing keys - this will also remove any + // incoming keys with the same key set ID + String sql = "DELETE FROM outgoingKeys" + + " WHERE transportId = ? AND keySetId = ?"; + ps = txn.prepareStatement(sql); + ps.setString(1, t.getString()); + ps.setInt(2, k.getInt()); + int affected = ps.executeUpdate(); + if (affected < 0) throw new DbStateException(); + ps.close(); + } catch (SQLException e) { + tryToClose(ps); + throw new DbException(e); + } + } + @Override public void resetExpiryTime(Connection txn, ContactId c, MessageId m) throws DbException { @@ -2791,18 +2868,18 @@ abstract class JdbcDatabase implements Database<Connection> { } @Override - public void setReorderingWindow(Connection txn, ContactId c, TransportId t, + public void setReorderingWindow(Connection txn, KeySetId k, TransportId t, long rotationPeriod, long base, byte[] bitmap) throws DbException { PreparedStatement ps = null; try { String sql = "UPDATE incomingKeys SET base = ?, bitmap = ?" - + " WHERE contactId = ? AND transportId = ?" + + " WHERE transportId = ? AND keySetId = ?" + " AND rotationPeriod = ?"; ps = txn.prepareStatement(sql); ps.setLong(1, base); ps.setBytes(2, bitmap); - ps.setInt(3, c.getInt()); - ps.setString(4, t.getString()); + ps.setString(3, t.getString()); + ps.setInt(4, k.getInt()); ps.setLong(5, rotationPeriod); int affected = ps.executeUpdate(); if (affected < 0 || affected > 1) throw new DbStateException(); @@ -2813,6 +2890,23 @@ abstract class JdbcDatabase implements Database<Connection> { } } + @Override + public void setTransportKeysActive(Connection txn, TransportId t, + KeySetId k) throws DbException { + PreparedStatement ps = null; + try { + String sql = "UPDATE outgoingKeys SET active = true" + + " WHERE transportId = ? AND keySetId = ?"; + ps = txn.prepareStatement(sql); + int affected = ps.executeUpdate(); + if (affected < 0 || affected > 1) throw new DbStateException(); + ps.close(); + } catch (SQLException e) { + tryToClose(ps); + throw new DbException(e); + } + } + @Override public void updateExpiryTime(Connection txn, ContactId c, MessageId m, int maxLatency) throws DbException { @@ -2848,45 +2942,12 @@ abstract class JdbcDatabase implements Database<Connection> { } @Override - public void updateTransportKeys(Connection txn, - Map<ContactId, TransportKeys> keys) throws DbException { - PreparedStatement ps = null; - try { - // Delete any existing incoming keys - String sql = "DELETE FROM incomingKeys" - + " WHERE contactId = ?" - + " AND transportId = ?"; - ps = txn.prepareStatement(sql); - for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { - ps.setInt(1, e.getKey().getInt()); - ps.setString(2, e.getValue().getTransportId().getString()); - ps.addBatch(); - } - int[] batchAffected = ps.executeBatch(); - if (batchAffected.length != keys.size()) - throw new DbStateException(); - ps.close(); - // Delete any existing outgoing keys - sql = "DELETE FROM outgoingKeys" - + " WHERE contactId = ?" - + " AND transportId = ?"; - ps = txn.prepareStatement(sql); - for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { - ps.setInt(1, e.getKey().getInt()); - ps.setString(2, e.getValue().getTransportId().getString()); - ps.addBatch(); - } - batchAffected = ps.executeBatch(); - if (batchAffected.length != keys.size()) - throw new DbStateException(); - ps.close(); - } catch (SQLException e) { - tryToClose(ps); - throw new DbException(e); - } - // Store the new keys - for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { - addTransportKeys(txn, e.getKey(), e.getValue()); + public void updateTransportKeys(Connection txn, Collection<KeySet> keys) + throws DbException { + for (KeySet ks : keys) { + TransportKeys k = ks.getTransportKeys(); + removeTransportKeys(txn, k.getTransportId(), ks.getKeySetId()); + addTransportKeys(txn, ks.getContactId(), k); } } } diff --git a/bramble-core/src/main/java/org/briarproject/bramble/transport/KeyManagerImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/transport/KeyManagerImpl.java index d0c6fd709d619911f6cc9862b99947df3eb82bfe..bbd5e5ec15c3e670db0217e8b5ccb63bc799182e 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/transport/KeyManagerImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/transport/KeyManagerImpl.java @@ -19,6 +19,7 @@ import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.duplex.DuplexPluginFactory; import org.briarproject.bramble.api.plugin.simplex.SimplexPluginFactory; import org.briarproject.bramble.api.transport.KeyManager; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.StreamContext; import java.util.HashMap; @@ -104,6 +105,67 @@ class KeyManagerImpl implements KeyManager, Service, EventListener { m.addContact(txn, c, master, timestamp, alice); } + @Override + public Map<TransportId, KeySetId> addUnboundKeys(Transaction txn, + SecretKey master, long timestamp, boolean alice) + throws DbException { + Map<TransportId, KeySetId> ids = new HashMap<>(); + for (Entry<TransportId, TransportKeyManager> e : managers.entrySet()) { + TransportId t = e.getKey(); + TransportKeyManager m = e.getValue(); + ids.put(t, m.addUnboundKeys(txn, master, timestamp, alice)); + } + return ids; + } + + @Override + public void bindKeys(Transaction txn, ContactId c, + Map<TransportId, KeySetId> keys) throws DbException { + for (Entry<TransportId, KeySetId> e : keys.entrySet()) { + TransportId t = e.getKey(); + TransportKeyManager m = managers.get(t); + if (m == null) { + if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t); + } else { + m.bindKeys(txn, c, e.getValue()); + } + } + } + + @Override + public void activateKeys(Transaction txn, Map<TransportId, KeySetId> keys) + throws DbException { + for (Entry<TransportId, KeySetId> e : keys.entrySet()) { + TransportId t = e.getKey(); + TransportKeyManager m = managers.get(t); + if (m == null) { + if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t); + } else { + m.activateKeys(txn, e.getValue()); + } + } + } + + @Override + public void removeKeys(Transaction txn, Map<TransportId, KeySetId> keys) + throws DbException { + for (Entry<TransportId, KeySetId> e : keys.entrySet()) { + TransportId t = e.getKey(); + TransportKeyManager m = managers.get(t); + if (m == null) { + if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t); + } else { + m.removeKeys(txn, e.getValue()); + } + } + } + + @Override + public boolean canSendOutgoingStreams(ContactId c, TransportId t) { + TransportKeyManager m = managers.get(t); + return m == null ? false : m.canSendOutgoingStreams(c); + } + @Override public StreamContext getStreamContext(ContactId c, TransportId t) throws DbException { @@ -114,7 +176,7 @@ class KeyManagerImpl implements KeyManager, Service, EventListener { if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t); return null; } - StreamContext ctx = null; + StreamContext ctx; Transaction txn = db.startTransaction(false); try { ctx = m.getStreamContext(txn, c); @@ -133,7 +195,7 @@ class KeyManagerImpl implements KeyManager, Service, EventListener { if (LOG.isLoggable(INFO)) LOG.info("No key manager for " + t); return null; } - StreamContext ctx = null; + StreamContext ctx; Transaction txn = db.startTransaction(false); try { ctx = m.getStreamContext(txn, tag); diff --git a/bramble-core/src/main/java/org/briarproject/bramble/transport/MutableKeySet.java b/bramble-core/src/main/java/org/briarproject/bramble/transport/MutableKeySet.java new file mode 100644 index 0000000000000000000000000000000000000000..b55c5aef4f8551741c52e6a48d8630aebf86b29e --- /dev/null +++ b/bramble-core/src/main/java/org/briarproject/bramble/transport/MutableKeySet.java @@ -0,0 +1,34 @@ +package org.briarproject.bramble.transport; + +import org.briarproject.bramble.api.contact.ContactId; +import org.briarproject.bramble.api.transport.KeySetId; + +import javax.annotation.Nullable; + +public class MutableKeySet { + + private final KeySetId keySetId; + @Nullable + private final ContactId contactId; + private final MutableTransportKeys transportKeys; + + public MutableKeySet(KeySetId keySetId, @Nullable ContactId contactId, + MutableTransportKeys transportKeys) { + this.keySetId = keySetId; + this.contactId = contactId; + this.transportKeys = transportKeys; + } + + public KeySetId getKeySetId() { + return keySetId; + } + + @Nullable + public ContactId getContactId() { + return contactId; + } + + public MutableTransportKeys getTransportKeys() { + return transportKeys; + } +} diff --git a/bramble-core/src/main/java/org/briarproject/bramble/transport/MutableOutgoingKeys.java b/bramble-core/src/main/java/org/briarproject/bramble/transport/MutableOutgoingKeys.java index aaafec13bd192a0600b1d73170039432b102732b..c195f445cd2e151e0d408b7c9e2fdc15310bc229 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/transport/MutableOutgoingKeys.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/transport/MutableOutgoingKeys.java @@ -13,17 +13,19 @@ class MutableOutgoingKeys { private final SecretKey tagKey, headerKey; private final long rotationPeriod; private long streamCounter; + private boolean active; MutableOutgoingKeys(OutgoingKeys out) { tagKey = out.getTagKey(); headerKey = out.getHeaderKey(); rotationPeriod = out.getRotationPeriod(); streamCounter = out.getStreamCounter(); + active = out.isActive(); } OutgoingKeys snapshot() { return new OutgoingKeys(tagKey, headerKey, rotationPeriod, - streamCounter); + streamCounter, active); } SecretKey getTagKey() { @@ -45,4 +47,12 @@ class MutableOutgoingKeys { void incrementStreamCounter() { streamCounter++; } + + boolean isActive() { + return active; + } + + void activate() { + active = true; + } } diff --git a/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManager.java b/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManager.java index 6aa6d360fe2b24af7a643167347c955f75e90bdb..5ca159a4259640317a80e24953948f27ce0c38e3 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManager.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManager.java @@ -5,6 +5,7 @@ import org.briarproject.bramble.api.crypto.SecretKey; import org.briarproject.bramble.api.db.DbException; import org.briarproject.bramble.api.db.Transaction; import org.briarproject.bramble.api.nullsafety.NotNullByDefault; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.StreamContext; import javax.annotation.Nullable; @@ -17,8 +18,19 @@ interface TransportKeyManager { void addContact(Transaction txn, ContactId c, SecretKey master, long timestamp, boolean alice) throws DbException; + KeySetId addUnboundKeys(Transaction txn, SecretKey master, long timestamp, + boolean alice) throws DbException; + + void bindKeys(Transaction txn, ContactId c, KeySetId k) throws DbException; + + void activateKeys(Transaction txn, KeySetId k) throws DbException; + + void removeKeys(Transaction txn, KeySetId k) throws DbException; + void removeContact(ContactId c); + boolean canSendOutgoingStreams(ContactId c); + @Nullable StreamContext getStreamContext(Transaction txn, ContactId c) throws DbException; diff --git a/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java index 60b48427ff18f07703a147f827ae39639c6f629e..63bcdbc778a5cd6bb6ed7e48c96d6ee3737dc874 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java @@ -11,19 +11,24 @@ import org.briarproject.bramble.api.nullsafety.NotNullByDefault; import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.system.Clock; import org.briarproject.bramble.api.system.Scheduler; +import org.briarproject.bramble.api.transport.KeySet; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.StreamContext; import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.transport.ReorderingWindow.Change; +import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.Iterator; import java.util.Map; -import java.util.Map.Entry; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Logger; +import javax.annotation.Nullable; import javax.annotation.concurrent.ThreadSafe; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -47,12 +52,13 @@ class TransportKeyManagerImpl implements TransportKeyManager { private final Clock clock; private final TransportId transportId; private final long rotationPeriodLength; - private final ReentrantLock lock; + private final AtomicBoolean used = new AtomicBoolean(false); + private final ReentrantLock lock = new ReentrantLock(); // The following are locking: lock - private final Map<Bytes, TagContext> inContexts; - private final Map<ContactId, MutableOutgoingKeys> outContexts; - private final Map<ContactId, MutableTransportKeys> keys; + private final Map<KeySetId, MutableKeySet> keys = new HashMap<>(); + private final Map<Bytes, TagContext> inContexts = new HashMap<>(); + private final Map<ContactId, MutableKeySet> outContexts = new HashMap<>(); TransportKeyManagerImpl(DatabaseComponent db, TransportCrypto transportCrypto, Executor dbExecutor, @@ -65,20 +71,16 @@ class TransportKeyManagerImpl implements TransportKeyManager { this.clock = clock; this.transportId = transportId; rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE; - lock = new ReentrantLock(); - inContexts = new HashMap<>(); - outContexts = new HashMap<>(); - keys = new HashMap<>(); } @Override public void start(Transaction txn) throws DbException { + if (used.getAndSet(true)) throw new IllegalStateException(); long now = clock.currentTimeMillis(); lock.lock(); try { // Load the transport keys from the DB - Map<ContactId, TransportKeys> loaded = - db.getTransportKeys(txn, transportId); + Collection<KeySet> loaded = db.getTransportKeys(txn, transportId); // Rotate the keys to the current rotation period RotationResult rotationResult = rotateKeys(loaded, now); // Initialise mutable state for all contacts @@ -93,41 +95,48 @@ class TransportKeyManagerImpl implements TransportKeyManager { scheduleKeyRotation(now); } - private RotationResult rotateKeys(Map<ContactId, TransportKeys> keys, - long now) { + private RotationResult rotateKeys(Collection<KeySet> keys, long now) { RotationResult rotationResult = new RotationResult(); long rotationPeriod = now / rotationPeriodLength; - for (Entry<ContactId, TransportKeys> e : keys.entrySet()) { - ContactId c = e.getKey(); - TransportKeys k = e.getValue(); + for (KeySet ks : keys) { + TransportKeys k = ks.getTransportKeys(); TransportKeys k1 = transportCrypto.rotateTransportKeys(k, rotationPeriod); + KeySet ks1 = new KeySet(ks.getKeySetId(), ks.getContactId(), k1); if (k1.getRotationPeriod() > k.getRotationPeriod()) - rotationResult.rotated.put(c, k1); - rotationResult.current.put(c, k1); + rotationResult.rotated.add(ks1); + rotationResult.current.add(ks1); } return rotationResult; } // Locking: lock - private void addKeys(Map<ContactId, TransportKeys> m) { - for (Entry<ContactId, TransportKeys> e : m.entrySet()) - addKeys(e.getKey(), new MutableTransportKeys(e.getValue())); + private void addKeys(Collection<KeySet> keys) { + for (KeySet ks : keys) { + addKeys(ks.getKeySetId(), ks.getContactId(), + new MutableTransportKeys(ks.getTransportKeys())); + } } // Locking: lock - private void addKeys(ContactId c, MutableTransportKeys m) { - encodeTags(c, m.getPreviousIncomingKeys()); - encodeTags(c, m.getCurrentIncomingKeys()); - encodeTags(c, m.getNextIncomingKeys()); - outContexts.put(c, m.getCurrentOutgoingKeys()); - keys.put(c, m); + private void addKeys(KeySetId keySetId, @Nullable ContactId contactId, + MutableTransportKeys m) { + MutableKeySet ks = new MutableKeySet(keySetId, contactId, m); + keys.put(keySetId, ks); + if (contactId != null) { + encodeTags(keySetId, contactId, m.getPreviousIncomingKeys()); + encodeTags(keySetId, contactId, m.getCurrentIncomingKeys()); + encodeTags(keySetId, contactId, m.getNextIncomingKeys()); + considerReplacingOutgoingKeys(ks); + } } // Locking: lock - private void encodeTags(ContactId c, MutableIncomingKeys inKeys) { + private void encodeTags(KeySetId keySetId, ContactId contactId, + MutableIncomingKeys inKeys) { for (long streamNumber : inKeys.getWindow().getUnseen()) { - TagContext tagCtx = new TagContext(c, inKeys, streamNumber); + TagContext tagCtx = + new TagContext(keySetId, contactId, inKeys, streamNumber); byte[] tag = new byte[TAG_LENGTH]; transportCrypto.encodeTag(tag, inKeys.getTagKey(), PROTOCOL_VERSION, streamNumber); @@ -135,6 +144,17 @@ class TransportKeyManagerImpl implements TransportKeyManager { } } + // Locking: lock + private void considerReplacingOutgoingKeys(MutableKeySet ks) { + // Use the active outgoing keys with the highest key set ID + if (ks.getTransportKeys().getCurrentOutgoingKeys().isActive()) { + MutableKeySet old = outContexts.get(ks.getContactId()); + if (old == null || + old.getKeySetId().getInt() < ks.getKeySetId().getInt()) + outContexts.put(ks.getContactId(), ks); + } + } + private void scheduleKeyRotation(long now) { long delay = rotationPeriodLength - now % rotationPeriodLength; scheduler.schedule((Runnable) this::rotateKeys, delay, MILLISECONDS); @@ -159,20 +179,82 @@ class TransportKeyManagerImpl implements TransportKeyManager { @Override public void addContact(Transaction txn, ContactId c, SecretKey master, long timestamp, boolean alice) throws DbException { + deriveAndAddKeys(txn, c, master, timestamp, alice, true); + } + + @Override + public KeySetId addUnboundKeys(Transaction txn, SecretKey master, + long timestamp, boolean alice) throws DbException { + return deriveAndAddKeys(txn, null, master, timestamp, alice, false); + } + + private KeySetId deriveAndAddKeys(Transaction txn, @Nullable ContactId c, + SecretKey master, long timestamp, boolean alice, boolean active) + throws DbException { lock.lock(); try { // Work out what rotation period the timestamp belongs to long rotationPeriod = timestamp / rotationPeriodLength; // Derive the transport keys TransportKeys k = transportCrypto.deriveTransportKeys(transportId, - master, rotationPeriod, alice); + master, rotationPeriod, alice, active); // Rotate the keys to the current rotation period if necessary rotationPeriod = clock.currentTimeMillis() / rotationPeriodLength; k = transportCrypto.rotateTransportKeys(k, rotationPeriod); - // Initialise mutable state for the contact - addKeys(c, new MutableTransportKeys(k)); // Write the keys back to the DB - db.addTransportKeys(txn, c, k); + KeySetId keySetId = db.addTransportKeys(txn, c, k); + // Initialise mutable state for the contact + addKeys(keySetId, c, new MutableTransportKeys(k)); + return keySetId; + } finally { + lock.unlock(); + } + } + + @Override + public void bindKeys(Transaction txn, ContactId c, KeySetId k) + throws DbException { + lock.lock(); + try { + MutableKeySet ks = keys.get(k); + if (ks == null) throw new IllegalArgumentException(); + // Check that the keys haven't already been bound + if (ks.getContactId() != null) throw new IllegalArgumentException(); + MutableTransportKeys m = ks.getTransportKeys(); + addKeys(k, c, m); + db.bindTransportKeys(txn, c, m.getTransportId(), k); + } finally { + lock.unlock(); + } + } + + @Override + public void activateKeys(Transaction txn, KeySetId k) throws DbException { + lock.lock(); + try { + MutableKeySet ks = keys.get(k); + if (ks == null) throw new IllegalArgumentException(); + // Check that the keys have been bound + if (ks.getContactId() == null) throw new IllegalArgumentException(); + MutableTransportKeys m = ks.getTransportKeys(); + m.getCurrentOutgoingKeys().activate(); + considerReplacingOutgoingKeys(ks); + db.setTransportKeysActive(txn, m.getTransportId(), k); + } finally { + lock.unlock(); + } + } + + @Override + public void removeKeys(Transaction txn, KeySetId k) throws DbException { + lock.lock(); + try { + MutableKeySet ks = keys.remove(k); + if (ks == null) throw new IllegalArgumentException(); + // Check that the keys haven't been bound + if (ks.getContactId() != null) throw new IllegalArgumentException(); + TransportId t = ks.getTransportKeys().getTransportId(); + db.removeTransportKeys(txn, t, k); } finally { lock.unlock(); } @@ -183,12 +265,29 @@ class TransportKeyManagerImpl implements TransportKeyManager { lock.lock(); try { // Remove mutable state for the contact - Iterator<Entry<Bytes, TagContext>> it = - inContexts.entrySet().iterator(); - while (it.hasNext()) - if (it.next().getValue().contactId.equals(c)) it.remove(); + Iterator<TagContext> it = inContexts.values().iterator(); + while (it.hasNext()) if (it.next().contactId.equals(c)) it.remove(); outContexts.remove(c); - keys.remove(c); + Iterator<MutableKeySet> it1 = keys.values().iterator(); + while (it1.hasNext()) { + ContactId c1 = it1.next().getContactId(); + if (c1 != null && c1.equals(c)) it1.remove(); + } + } finally { + lock.unlock(); + } + } + + @Override + public boolean canSendOutgoingStreams(ContactId c) { + lock.lock(); + try { + MutableKeySet ks = outContexts.get(c); + if (ks == null) return false; + MutableOutgoingKeys outKeys = + ks.getTransportKeys().getCurrentOutgoingKeys(); + if (!outKeys.isActive()) throw new AssertionError(); + return outKeys.getStreamCounter() <= MAX_32_BIT_UNSIGNED; } finally { lock.unlock(); } @@ -200,8 +299,11 @@ class TransportKeyManagerImpl implements TransportKeyManager { lock.lock(); try { // Look up the outgoing keys for the contact - MutableOutgoingKeys outKeys = outContexts.get(c); - if (outKeys == null) return null; + MutableKeySet ks = outContexts.get(c); + if (ks == null) return null; + MutableOutgoingKeys outKeys = + ks.getTransportKeys().getCurrentOutgoingKeys(); + if (!outKeys.isActive()) throw new AssertionError(); if (outKeys.getStreamCounter() > MAX_32_BIT_UNSIGNED) return null; // Create a stream context StreamContext ctx = new StreamContext(c, transportId, @@ -209,8 +311,7 @@ class TransportKeyManagerImpl implements TransportKeyManager { outKeys.getStreamCounter()); // Increment the stream counter and write it back to the DB outKeys.incrementStreamCounter(); - db.incrementStreamCounter(txn, c, transportId, - outKeys.getRotationPeriod()); + db.incrementStreamCounter(txn, transportId, ks.getKeySetId()); return ctx; } finally { lock.unlock(); @@ -238,8 +339,9 @@ class TransportKeyManagerImpl implements TransportKeyManager { byte[] addTag = new byte[TAG_LENGTH]; transportCrypto.encodeTag(addTag, inKeys.getTagKey(), PROTOCOL_VERSION, streamNumber); - inContexts.put(new Bytes(addTag), new TagContext( - tagCtx.contactId, inKeys, streamNumber)); + TagContext tagCtx1 = new TagContext(tagCtx.keySetId, + tagCtx.contactId, inKeys, streamNumber); + inContexts.put(new Bytes(addTag), tagCtx1); } // Remove tags for any stream numbers removed from the window for (long streamNumber : change.getRemoved()) { @@ -250,9 +352,19 @@ class TransportKeyManagerImpl implements TransportKeyManager { inContexts.remove(new Bytes(removeTag)); } // Write the window back to the DB - db.setReorderingWindow(txn, tagCtx.contactId, transportId, + db.setReorderingWindow(txn, tagCtx.keySetId, transportId, inKeys.getRotationPeriod(), window.getBase(), window.getBitmap()); + // If the outgoing keys are inactive, activate them + MutableKeySet ks = keys.get(tagCtx.keySetId); + MutableOutgoingKeys outKeys = + ks.getTransportKeys().getCurrentOutgoingKeys(); + if (!outKeys.isActive()) { + LOG.info("Activating outgoing keys"); + outKeys.activate(); + considerReplacingOutgoingKeys(ks); + db.setTransportKeysActive(txn, transportId, tagCtx.keySetId); + } return ctx; } finally { lock.unlock(); @@ -264,9 +376,11 @@ class TransportKeyManagerImpl implements TransportKeyManager { lock.lock(); try { // Rotate the keys to the current rotation period - Map<ContactId, TransportKeys> snapshot = new HashMap<>(); - for (Entry<ContactId, MutableTransportKeys> e : keys.entrySet()) - snapshot.put(e.getKey(), e.getValue().snapshot()); + Collection<KeySet> snapshot = new ArrayList<>(keys.size()); + for (MutableKeySet ks : keys.values()) { + snapshot.add(new KeySet(ks.getKeySetId(), ks.getContactId(), + ks.getTransportKeys().snapshot())); + } RotationResult rotationResult = rotateKeys(snapshot, now); // Rebuild the mutable state for all contacts inContexts.clear(); @@ -285,12 +399,14 @@ class TransportKeyManagerImpl implements TransportKeyManager { private static class TagContext { + private final KeySetId keySetId; private final ContactId contactId; private final MutableIncomingKeys inKeys; private final long streamNumber; - private TagContext(ContactId contactId, MutableIncomingKeys inKeys, - long streamNumber) { + private TagContext(KeySetId keySetId, ContactId contactId, + MutableIncomingKeys inKeys, long streamNumber) { + this.keySetId = keySetId; this.contactId = contactId; this.inKeys = inKeys; this.streamNumber = streamNumber; @@ -299,11 +415,7 @@ class TransportKeyManagerImpl implements TransportKeyManager { private static class RotationResult { - private final Map<ContactId, TransportKeys> current, rotated; - - private RotationResult() { - current = new HashMap<>(); - rotated = new HashMap<>(); - } + private final Collection<KeySet> current = new ArrayList<>(); + private final Collection<KeySet> rotated = new ArrayList<>(); } } diff --git a/bramble-core/src/test/java/org/briarproject/bramble/crypto/KeyDerivationTest.java b/bramble-core/src/test/java/org/briarproject/bramble/crypto/KeyDerivationTest.java index 81f73f8d28f410a5c434d4d107815c416ff59d37..dc51966ffb3e1c15acc1fa66b0657b057cbe9624 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/crypto/KeyDerivationTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/crypto/KeyDerivationTest.java @@ -33,7 +33,7 @@ public class KeyDerivationTest extends BrambleTestCase { @Test public void testKeysAreDistinct() { TransportKeys k = transportCrypto.deriveTransportKeys(transportId, - master, 123, true); + master, 123, true, true); assertAllDifferent(k); } @@ -41,9 +41,9 @@ public class KeyDerivationTest extends BrambleTestCase { public void testCurrentKeysMatchCurrentKeysOfContact() { // Start in rotation period 123 TransportKeys kA = transportCrypto.deriveTransportKeys(transportId, - master, 123, true); + master, 123, true, true); TransportKeys kB = transportCrypto.deriveTransportKeys(transportId, - master, 123, false); + master, 123, false, true); // Alice's incoming keys should equal Bob's outgoing keys assertArrayEquals(kA.getCurrentIncomingKeys().getTagKey().getBytes(), kB.getCurrentOutgoingKeys().getTagKey().getBytes()); @@ -73,9 +73,9 @@ public class KeyDerivationTest extends BrambleTestCase { public void testPreviousKeysMatchPreviousKeysOfContact() { // Start in rotation period 123 TransportKeys kA = transportCrypto.deriveTransportKeys(transportId, - master, 123, true); + master, 123, true, true); TransportKeys kB = transportCrypto.deriveTransportKeys(transportId, - master, 123, false); + master, 123, false, true); // Compare Alice's previous keys in period 456 with Bob's current keys // in period 455 kA = transportCrypto.rotateTransportKeys(kA, 456); @@ -100,9 +100,9 @@ public class KeyDerivationTest extends BrambleTestCase { public void testNextKeysMatchNextKeysOfContact() { // Start in rotation period 123 TransportKeys kA = transportCrypto.deriveTransportKeys(transportId, - master, 123, true); + master, 123, true, true); TransportKeys kB = transportCrypto.deriveTransportKeys(transportId, - master, 123, false); + master, 123, false, true); // Compare Alice's current keys in period 456 with Bob's next keys in // period 455 kA = transportCrypto.rotateTransportKeys(kA, 456); @@ -127,9 +127,9 @@ public class KeyDerivationTest extends BrambleTestCase { SecretKey master1 = getSecretKey(); assertFalse(Arrays.equals(master.getBytes(), master1.getBytes())); TransportKeys k = transportCrypto.deriveTransportKeys(transportId, - master, 123, true); + master, 123, true, true); TransportKeys k1 = transportCrypto.deriveTransportKeys(transportId, - master1, 123, true); + master1, 123, true, true); assertAllDifferent(k, k1); } @@ -138,9 +138,9 @@ public class KeyDerivationTest extends BrambleTestCase { TransportId transportId1 = new TransportId("id1"); assertFalse(transportId.getString().equals(transportId1.getString())); TransportKeys k = transportCrypto.deriveTransportKeys(transportId, - master, 123, true); + master, 123, true, true); TransportKeys k1 = transportCrypto.deriveTransportKeys(transportId1, - master, 123, true); + master, 123, true, true); assertAllDifferent(k, k1); } diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java index 0ed4574618fc17e8eac6f0dcb59cb395f206e78d..f0149bb248948d6bb51d7bd1693a741407d464cc 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java @@ -44,6 +44,8 @@ import org.briarproject.bramble.api.sync.event.MessageToRequestEvent; import org.briarproject.bramble.api.sync.event.MessagesAckedEvent; import org.briarproject.bramble.api.sync.event.MessagesSentEvent; import org.briarproject.bramble.api.transport.IncomingKeys; +import org.briarproject.bramble.api.transport.KeySet; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.test.BrambleMockTestCase; @@ -55,12 +57,10 @@ import org.junit.Test; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.Map; import java.util.concurrent.atomic.AtomicReference; import static java.util.Collections.emptyMap; import static java.util.Collections.singletonList; -import static java.util.Collections.singletonMap; import static org.briarproject.bramble.api.sync.Group.Visibility.INVISIBLE; import static org.briarproject.bramble.api.sync.Group.Visibility.SHARED; import static org.briarproject.bramble.api.sync.Group.Visibility.VISIBLE; @@ -71,6 +71,7 @@ import static org.briarproject.bramble.api.transport.TransportConstants.REORDERI import static org.briarproject.bramble.db.DatabaseConstants.MAX_OFFERED_MESSAGES; import static org.briarproject.bramble.test.TestUtils.getAuthor; import static org.briarproject.bramble.test.TestUtils.getLocalAuthor; +import static org.briarproject.bramble.test.TestUtils.getSecretKey; import static org.briarproject.bramble.util.StringUtils.getRandomString; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -100,6 +101,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { private final int maxLatency; private final ContactId contactId; private final Contact contact; + private final KeySetId keySetId; public DatabaseComponentImplTest() { clientId = new ClientId(getRandomString(123)); @@ -121,6 +123,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { contactId = new ContactId(234); contact = new Contact(contactId, author, localAuthor.getId(), true, true); + keySetId = new KeySetId(345); } private DatabaseComponent createDatabaseComponent(Database<Object> database, @@ -282,11 +285,11 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { throws Exception { context.checking(new Expectations() {{ // Check whether the contact is in the DB (which it's not) - exactly(18).of(database).startTransaction(); + exactly(17).of(database).startTransaction(); will(returnValue(txn)); - exactly(18).of(database).containsContact(txn, contactId); + exactly(17).of(database).containsContact(txn, contactId); will(returnValue(false)); - exactly(18).of(database).abortTransaction(txn); + exactly(17).of(database).abortTransaction(txn); }}); DatabaseComponent db = createDatabaseComponent(database, eventBus, shutdown); @@ -303,7 +306,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { transaction = db.startTransaction(false); try { - db.generateAck(transaction, contactId, 123); + db.bindTransportKeys(transaction, contactId, transportId, keySetId); fail(); } catch (NoSuchContactException expected) { // Expected @@ -313,7 +316,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { transaction = db.startTransaction(false); try { - db.generateBatch(transaction, contactId, 123, 456); + db.generateAck(transaction, contactId, 123); fail(); } catch (NoSuchContactException expected) { // Expected @@ -323,7 +326,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { transaction = db.startTransaction(false); try { - db.generateOffer(transaction, contactId, 123, 456); + db.generateBatch(transaction, contactId, 123, 456); fail(); } catch (NoSuchContactException expected) { // Expected @@ -333,7 +336,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { transaction = db.startTransaction(false); try { - db.generateRequest(transaction, contactId, 123); + db.generateOffer(transaction, contactId, 123, 456); fail(); } catch (NoSuchContactException expected) { // Expected @@ -343,7 +346,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { transaction = db.startTransaction(false); try { - db.getContact(transaction, contactId); + db.generateRequest(transaction, contactId, 123); fail(); } catch (NoSuchContactException expected) { // Expected @@ -353,7 +356,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { transaction = db.startTransaction(false); try { - db.getMessageStatus(transaction, contactId, groupId); + db.getContact(transaction, contactId); fail(); } catch (NoSuchContactException expected) { // Expected @@ -363,7 +366,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { transaction = db.startTransaction(false); try { - db.getMessageStatus(transaction, contactId, messageId); + db.getMessageStatus(transaction, contactId, groupId); fail(); } catch (NoSuchContactException expected) { // Expected @@ -373,7 +376,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { transaction = db.startTransaction(false); try { - db.incrementStreamCounter(transaction, contactId, transportId, 0); + db.getMessageStatus(transaction, contactId, messageId); fail(); } catch (NoSuchContactException expected) { // Expected @@ -454,17 +457,6 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { db.endTransaction(transaction); } - transaction = db.startTransaction(false); - try { - db.setReorderingWindow(transaction, contactId, transportId, 0, 0, - new byte[REORDERING_WINDOW_SIZE / 8]); - fail(); - } catch (NoSuchContactException expected) { - // Expected - } finally { - db.endTransaction(transaction); - } - transaction = db.startTransaction(false); try { db.setGroupVisibility(transaction, contactId, groupId, SHARED); @@ -777,13 +769,13 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { // endTransaction() oneOf(database).commitTransaction(txn); // Check whether the transport is in the DB (which it's not) - exactly(4).of(database).startTransaction(); + exactly(6).of(database).startTransaction(); will(returnValue(txn)); - exactly(2).of(database).containsContact(txn, contactId); + oneOf(database).containsContact(txn, contactId); will(returnValue(true)); - exactly(4).of(database).containsTransport(txn, transportId); + exactly(6).of(database).containsTransport(txn, transportId); will(returnValue(false)); - exactly(4).of(database).abortTransaction(txn); + exactly(6).of(database).abortTransaction(txn); }}); DatabaseComponent db = createDatabaseComponent(database, eventBus, shutdown); @@ -798,6 +790,16 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { db.endTransaction(transaction); } + transaction = db.startTransaction(false); + try { + db.bindTransportKeys(transaction, contactId, transportId, keySetId); + fail(); + } catch (NoSuchTransportException expected) { + // Expected + } finally { + db.endTransaction(transaction); + } + transaction = db.startTransaction(false); try { db.getTransportKeys(transaction, transportId); @@ -810,7 +812,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { transaction = db.startTransaction(false); try { - db.incrementStreamCounter(transaction, contactId, transportId, 0); + db.incrementStreamCounter(transaction, transportId, keySetId); fail(); } catch (NoSuchTransportException expected) { // Expected @@ -830,7 +832,17 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { transaction = db.startTransaction(false); try { - db.setReorderingWindow(transaction, contactId, transportId, 0, 0, + db.removeTransportKeys(transaction, transportId, keySetId); + fail(); + } catch (NoSuchTransportException expected) { + // Expected + } finally { + db.endTransaction(transaction); + } + + transaction = db.startTransaction(false); + try { + db.setReorderingWindow(transaction, keySetId, transportId, 0, 0, new byte[REORDERING_WINDOW_SIZE / 8]); fail(); } catch (NoSuchTransportException expected) { @@ -1303,15 +1315,13 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { @Test public void testTransportKeys() throws Exception { TransportKeys transportKeys = createTransportKeys(); - Map<ContactId, TransportKeys> keys = - singletonMap(contactId, transportKeys); + Collection<KeySet> keys = + singletonList(new KeySet(keySetId, contactId, transportKeys)); context.checking(new Expectations() {{ // startTransaction() oneOf(database).startTransaction(); will(returnValue(txn)); // updateTransportKeys() - oneOf(database).containsContact(txn, contactId); - will(returnValue(true)); oneOf(database).containsTransport(txn, transportId); will(returnValue(true)); oneOf(database).updateTransportKeys(txn, keys); @@ -1337,22 +1347,22 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { } private TransportKeys createTransportKeys() { - SecretKey inPrevTagKey = TestUtils.getSecretKey(); - SecretKey inPrevHeaderKey = TestUtils.getSecretKey(); + SecretKey inPrevTagKey = getSecretKey(); + SecretKey inPrevHeaderKey = getSecretKey(); IncomingKeys inPrev = new IncomingKeys(inPrevTagKey, inPrevHeaderKey, 1, 123, new byte[4]); - SecretKey inCurrTagKey = TestUtils.getSecretKey(); - SecretKey inCurrHeaderKey = TestUtils.getSecretKey(); + SecretKey inCurrTagKey = getSecretKey(); + SecretKey inCurrHeaderKey = getSecretKey(); IncomingKeys inCurr = new IncomingKeys(inCurrTagKey, inCurrHeaderKey, 2, 234, new byte[4]); - SecretKey inNextTagKey = TestUtils.getSecretKey(); - SecretKey inNextHeaderKey = TestUtils.getSecretKey(); + SecretKey inNextTagKey = getSecretKey(); + SecretKey inNextHeaderKey = getSecretKey(); IncomingKeys inNext = new IncomingKeys(inNextTagKey, inNextHeaderKey, 3, 345, new byte[4]); - SecretKey outCurrTagKey = TestUtils.getSecretKey(); - SecretKey outCurrHeaderKey = TestUtils.getSecretKey(); + SecretKey outCurrTagKey = getSecretKey(); + SecretKey outCurrHeaderKey = getSecretKey(); OutgoingKeys outCurr = new OutgoingKeys(outCurrTagKey, outCurrHeaderKey, - 2, 456); + 2, 456, true); return new TransportKeys(transportId, inPrev, inCurr, inNext, outCurr); } diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java index d81cfd85149cdc1ec6e2a53d0b1a4de2b4afeb7a..157b6902be00063bd0215b5b8a7dd011f9cc2837 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java @@ -19,6 +19,8 @@ import org.briarproject.bramble.api.sync.MessageStatus; import org.briarproject.bramble.api.sync.ValidationManager.State; import org.briarproject.bramble.api.system.Clock; import org.briarproject.bramble.api.transport.IncomingKeys; +import org.briarproject.bramble.api.transport.KeySet; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.TransportKeys; import org.briarproject.bramble.system.SystemClock; @@ -34,7 +36,6 @@ import java.sql.Connection; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -42,6 +43,10 @@ import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; import static java.util.concurrent.TimeUnit.SECONDS; import static org.briarproject.bramble.api.db.Metadata.REMOVE; import static org.briarproject.bramble.api.sync.Group.Visibility.INVISIBLE; @@ -86,6 +91,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { private final Message message; private final TransportId transportId; private final ContactId contactId; + private final KeySetId keySetId, keySetId1; JdbcDatabaseTest() throws Exception { groupId = new GroupId(getRandomId()); @@ -101,6 +107,8 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { message = new Message(messageId, groupId, timestamp, raw); transportId = new TransportId("id"); contactId = new ContactId(1); + keySetId = new KeySetId(1); + keySetId1 = new KeySetId(2); } protected abstract JdbcDatabase createDatabase(DatabaseConfig config, @@ -190,9 +198,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // The contact has not seen the message, so it should be sendable Collection<MessageId> ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); ids = db.getMessagesToOffer(txn, contactId, 100); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); // Changing the status to seen = true should make the message unsendable db.raiseSeenFlag(txn, contactId, messageId); @@ -228,9 +236,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // Marking the message delivered should make it sendable db.setMessageState(txn, messageId, DELIVERED); ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); ids = db.getMessagesToOffer(txn, contactId, 100); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); // Marking the message invalid should make it unsendable db.setMessageState(txn, messageId, INVALID); @@ -279,9 +287,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // Sharing the group should make the message sendable db.setGroupVisibility(txn, contactId, groupId, true); ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); ids = db.getMessagesToOffer(txn, contactId, 100); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); // Unsharing the group should make the message unsendable db.setGroupVisibility(txn, contactId, groupId, false); @@ -324,9 +332,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // Sharing the message should make it sendable db.setMessageShared(txn, messageId); ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); ids = db.getMessagesToOffer(txn, contactId, 100); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); db.commitTransaction(txn); db.close(); @@ -352,7 +360,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // The message is just the right size to send ids = db.getMessagesToSend(txn, contactId, size); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); db.commitTransaction(txn); db.close(); @@ -384,7 +392,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { db.lowerAckFlag(txn, contactId, Arrays.asList(messageId, messageId1)); // Both message IDs should have been removed - assertEquals(Collections.emptyList(), db.getMessagesToAck(txn, + assertEquals(emptyList(), db.getMessagesToAck(txn, contactId, 1234)); // Raise the ack flag again @@ -415,7 +423,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // Retrieve the message from the database and mark it as sent Collection<MessageId> ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); db.updateExpiryTime(txn, contactId, messageId, Integer.MAX_VALUE); // The message should no longer be sendable @@ -626,31 +634,31 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // The group should not be visible to the contact assertEquals(INVISIBLE, db.getGroupVisibility(txn, contactId, groupId)); - assertEquals(Collections.emptyMap(), + assertEquals(emptyMap(), db.getGroupVisibility(txn, groupId)); // Make the group visible to the contact db.addGroupVisibility(txn, contactId, groupId, false); assertEquals(VISIBLE, db.getGroupVisibility(txn, contactId, groupId)); - assertEquals(Collections.singletonMap(contactId, false), + assertEquals(singletonMap(contactId, false), db.getGroupVisibility(txn, groupId)); // Share the group with the contact db.setGroupVisibility(txn, contactId, groupId, true); assertEquals(SHARED, db.getGroupVisibility(txn, contactId, groupId)); - assertEquals(Collections.singletonMap(contactId, true), + assertEquals(singletonMap(contactId, true), db.getGroupVisibility(txn, groupId)); // Unshare the group with the contact db.setGroupVisibility(txn, contactId, groupId, false); assertEquals(VISIBLE, db.getGroupVisibility(txn, contactId, groupId)); - assertEquals(Collections.singletonMap(contactId, false), + assertEquals(singletonMap(contactId, false), db.getGroupVisibility(txn, groupId)); // Make the group invisible again db.removeGroupVisibility(txn, contactId, groupId); assertEquals(INVISIBLE, db.getGroupVisibility(txn, contactId, groupId)); - assertEquals(Collections.emptyMap(), + assertEquals(emptyMap(), db.getGroupVisibility(txn, groupId)); db.commitTransaction(txn); @@ -660,48 +668,125 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { @Test public void testTransportKeys() throws Exception { TransportKeys keys = createTransportKeys(); + TransportKeys keys1 = createTransportKeys(); Database<Connection> db = open(false); Connection txn = db.startTransaction(); // Initially there should be no transport keys in the database - assertEquals(Collections.emptyMap(), - db.getTransportKeys(txn, transportId)); + assertEquals(emptyList(), db.getTransportKeys(txn, transportId)); // Add the contact, the transport and the transport keys db.addLocalAuthor(txn, localAuthor); assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), true, true)); db.addTransport(txn, transportId, 123); - db.addTransportKeys(txn, contactId, keys); + assertEquals(keySetId, db.addTransportKeys(txn, contactId, keys)); + assertEquals(keySetId1, db.addTransportKeys(txn, contactId, keys1)); // Retrieve the transport keys - Map<ContactId, TransportKeys> newKeys = - db.getTransportKeys(txn, transportId); - assertEquals(1, newKeys.size()); - Entry<ContactId, TransportKeys> e = - newKeys.entrySet().iterator().next(); - assertEquals(contactId, e.getKey()); - TransportKeys k = e.getValue(); - assertEquals(transportId, k.getTransportId()); - assertKeysEquals(keys.getPreviousIncomingKeys(), - k.getPreviousIncomingKeys()); - assertKeysEquals(keys.getCurrentIncomingKeys(), - k.getCurrentIncomingKeys()); - assertKeysEquals(keys.getNextIncomingKeys(), - k.getNextIncomingKeys()); - assertKeysEquals(keys.getCurrentOutgoingKeys(), - k.getCurrentOutgoingKeys()); + Collection<KeySet> allKeys = db.getTransportKeys(txn, transportId); + assertEquals(2, allKeys.size()); + for (KeySet ks : allKeys) { + assertEquals(contactId, ks.getContactId()); + if (ks.getKeySetId().equals(keySetId)) { + assertKeysEquals(keys, ks.getTransportKeys()); + } else { + assertEquals(keySetId1, ks.getKeySetId()); + assertKeysEquals(keys1, ks.getTransportKeys()); + } + } // Removing the contact should remove the transport keys db.removeContact(txn, contactId); - assertEquals(Collections.emptyMap(), - db.getTransportKeys(txn, transportId)); + assertEquals(emptyList(), db.getTransportKeys(txn, transportId)); + + db.commitTransaction(txn); + db.close(); + } + + @Test + public void testUnboundTransportKeys() throws Exception { + TransportKeys keys = createTransportKeys(); + TransportKeys keys1 = createTransportKeys(); + + Database<Connection> db = open(false); + Connection txn = db.startTransaction(); + + // Initially there should be no transport keys in the database + assertEquals(emptyList(), db.getTransportKeys(txn, transportId)); + + // Add the contact, the transport and the unbound transport keys + db.addLocalAuthor(txn, localAuthor); + assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), + true, true)); + db.addTransport(txn, transportId, 123); + assertEquals(keySetId, db.addTransportKeys(txn, null, keys)); + assertEquals(keySetId1, db.addTransportKeys(txn, null, keys1)); + + // Retrieve the transport keys + Collection<KeySet> allKeys = db.getTransportKeys(txn, transportId); + assertEquals(2, allKeys.size()); + for (KeySet ks : allKeys) { + assertNull(ks.getContactId()); + if (ks.getKeySetId().equals(keySetId)) { + assertKeysEquals(keys, ks.getTransportKeys()); + } else { + assertEquals(keySetId1, ks.getKeySetId()); + assertKeysEquals(keys1, ks.getTransportKeys()); + } + } + + // Bind the first set of transport keys + db.bindTransportKeys(txn, contactId, transportId, keySetId); + + // Retrieve the keys again - the first set should be bound + allKeys = db.getTransportKeys(txn, transportId); + assertEquals(2, allKeys.size()); + for (KeySet ks : allKeys) { + if (ks.getKeySetId().equals(keySetId)) { + assertEquals(contactId, ks.getContactId()); + assertKeysEquals(keys, ks.getTransportKeys()); + } else { + assertEquals(keySetId1, ks.getKeySetId()); + assertNull(ks.getContactId()); + assertKeysEquals(keys1, ks.getTransportKeys()); + } + } + + // Remove the unbound transport keys + db.removeTransportKeys(txn, transportId, keySetId1); + + // Retrieve the keys again - the second set should be gone + allKeys = db.getTransportKeys(txn, transportId); + assertEquals(1, allKeys.size()); + KeySet ks = allKeys.iterator().next(); + assertEquals(keySetId, ks.getKeySetId()); + assertEquals(contactId, ks.getContactId()); + assertKeysEquals(keys, ks.getTransportKeys()); + + // Removing the transport should remove the remaining transport keys + db.removeTransport(txn, transportId); + assertEquals(emptyList(), db.getTransportKeys(txn, transportId)); db.commitTransaction(txn); db.close(); } + private void assertKeysEquals(TransportKeys expected, + TransportKeys actual) { + assertEquals(expected.getTransportId(), actual.getTransportId()); + assertEquals(expected.getRotationPeriod(), actual.getRotationPeriod()); + assertKeysEquals(expected.getPreviousIncomingKeys(), + actual.getPreviousIncomingKeys()); + assertKeysEquals(expected.getCurrentIncomingKeys(), + actual.getCurrentIncomingKeys()); + assertKeysEquals(expected.getNextIncomingKeys(), + actual.getNextIncomingKeys()); + assertKeysEquals(expected.getCurrentOutgoingKeys(), + actual.getCurrentOutgoingKeys()); + } + private void assertKeysEquals(IncomingKeys expected, IncomingKeys actual) { assertArrayEquals(expected.getTagKey().getBytes(), actual.getTagKey().getBytes()); @@ -719,6 +804,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { actual.getHeaderKey().getBytes()); assertEquals(expected.getRotationPeriod(), actual.getRotationPeriod()); assertEquals(expected.getStreamCounter(), actual.getStreamCounter()); + assertEquals(expected.isActive(), actual.isActive()); } @Test @@ -735,18 +821,18 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), true, true)); db.addTransport(txn, transportId, 123); - db.updateTransportKeys(txn, Collections.singletonMap(contactId, keys)); + db.updateTransportKeys(txn, + singletonList(new KeySet(keySetId, contactId, keys))); // Increment the stream counter twice and retrieve the transport keys - db.incrementStreamCounter(txn, contactId, transportId, rotationPeriod); - db.incrementStreamCounter(txn, contactId, transportId, rotationPeriod); - Map<ContactId, TransportKeys> newKeys = - db.getTransportKeys(txn, transportId); + db.incrementStreamCounter(txn, transportId, keySetId); + db.incrementStreamCounter(txn, transportId, keySetId); + Collection<KeySet> newKeys = db.getTransportKeys(txn, transportId); assertEquals(1, newKeys.size()); - Entry<ContactId, TransportKeys> e = - newKeys.entrySet().iterator().next(); - assertEquals(contactId, e.getKey()); - TransportKeys k = e.getValue(); + KeySet ks = newKeys.iterator().next(); + assertEquals(keySetId, ks.getKeySetId()); + assertEquals(contactId, ks.getContactId()); + TransportKeys k = ks.getTransportKeys(); assertEquals(transportId, k.getTransportId()); OutgoingKeys outCurr = k.getCurrentOutgoingKeys(); assertEquals(rotationPeriod, outCurr.getRotationPeriod()); @@ -771,19 +857,19 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), true, true)); db.addTransport(txn, transportId, 123); - db.updateTransportKeys(txn, Collections.singletonMap(contactId, keys)); + db.updateTransportKeys(txn, + singletonList(new KeySet(keySetId, contactId, keys))); // Update the reordering window and retrieve the transport keys new Random().nextBytes(bitmap); - db.setReorderingWindow(txn, contactId, transportId, rotationPeriod, + db.setReorderingWindow(txn, keySetId, transportId, rotationPeriod, base + 1, bitmap); - Map<ContactId, TransportKeys> newKeys = - db.getTransportKeys(txn, transportId); + Collection<KeySet> newKeys = db.getTransportKeys(txn, transportId); assertEquals(1, newKeys.size()); - Entry<ContactId, TransportKeys> e = - newKeys.entrySet().iterator().next(); - assertEquals(contactId, e.getKey()); - TransportKeys k = e.getValue(); + KeySet ks = newKeys.iterator().next(); + assertEquals(keySetId, ks.getKeySetId()); + assertEquals(contactId, ks.getContactId()); + TransportKeys k = ks.getTransportKeys(); assertEquals(transportId, k.getTransportId()); IncomingKeys inCurr = k.getCurrentIncomingKeys(); assertEquals(rotationPeriod, inCurr.getRotationPeriod()); @@ -830,18 +916,18 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { db.addLocalAuthor(txn, localAuthor); Collection<ContactId> contacts = db.getContacts(txn, localAuthor.getId()); - assertEquals(Collections.emptyList(), contacts); + assertEquals(emptyList(), contacts); // Add a contact associated with the local author assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), true, true)); contacts = db.getContacts(txn, localAuthor.getId()); - assertEquals(Collections.singletonList(contactId), contacts); + assertEquals(singletonList(contactId), contacts); // Remove the local author - the contact should be removed db.removeLocalAuthor(txn, localAuthor.getId()); contacts = db.getContacts(txn, localAuthor.getId()); - assertEquals(Collections.emptyList(), contacts); + assertEquals(emptyList(), contacts); assertFalse(db.containsContact(txn, contactId)); db.commitTransaction(txn); @@ -1560,9 +1646,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // The message should be sendable Collection<MessageId> ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); ids = db.getMessagesToOffer(txn, contactId, 100); - assertEquals(Collections.singletonList(messageId), ids); + assertEquals(singletonList(messageId), ids); // The raw message should not be null assertNotNull(db.getRawMessage(txn, messageId)); @@ -1735,7 +1821,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { SecretKey outCurrTagKey = getSecretKey(); SecretKey outCurrHeaderKey = getSecretKey(); OutgoingKeys outCurr = new OutgoingKeys(outCurrTagKey, outCurrHeaderKey, - 2, 456); + 2, 456, true); return new TransportKeys(transportId, inPrev, inCurr, inNext, outCurr); } diff --git a/bramble-core/src/test/java/org/briarproject/bramble/transport/KeyManagerImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/transport/KeyManagerImplTest.java index edf073a4f6df64a53b1dceda71f6b06883b00f92..ed7a357ed092f32427496906bc27ddf009ae2697 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/transport/KeyManagerImplTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/transport/KeyManagerImplTest.java @@ -12,19 +12,20 @@ import org.briarproject.bramble.api.identity.AuthorId; import org.briarproject.bramble.api.plugin.PluginConfig; import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.simplex.SimplexPluginFactory; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.StreamContext; -import org.briarproject.bramble.test.BrambleTestCase; +import org.briarproject.bramble.test.BrambleMockTestCase; import org.jmock.Expectations; -import org.jmock.Mockery; import org.jmock.lib.concurrent.DeterministicExecutor; import org.junit.Before; import org.junit.Test; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.Random; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; import static org.briarproject.bramble.api.transport.TransportConstants.TAG_LENGTH; import static org.briarproject.bramble.test.TestUtils.getAuthor; import static org.briarproject.bramble.test.TestUtils.getRandomBytes; @@ -32,31 +33,29 @@ import static org.briarproject.bramble.test.TestUtils.getRandomId; import static org.briarproject.bramble.test.TestUtils.getSecretKey; import static org.junit.Assert.assertEquals; -public class KeyManagerImplTest extends BrambleTestCase { +public class KeyManagerImplTest extends BrambleMockTestCase { - private final Mockery context = new Mockery(); - private final KeyManagerImpl keyManager; private final DatabaseComponent db = context.mock(DatabaseComponent.class); private final PluginConfig pluginConfig = context.mock(PluginConfig.class); private final TransportKeyManagerFactory transportKeyManagerFactory = context.mock(TransportKeyManagerFactory.class); private final TransportKeyManager transportKeyManager = context.mock(TransportKeyManager.class); + private final DeterministicExecutor executor = new DeterministicExecutor(); private final Transaction txn = new Transaction(null, false); - private final ContactId contactId = new ContactId(42); - private final ContactId inactiveContactId = new ContactId(43); - private final TransportId transportId = new TransportId("tId"); - private final TransportId unknownTransportId = new TransportId("id"); + private final ContactId contactId = new ContactId(123); + private final ContactId inactiveContactId = new ContactId(234); + private final KeySetId keySetId = new KeySetId(345); + private final TransportId transportId = new TransportId("known"); + private final TransportId unknownTransportId = new TransportId("unknown"); private final StreamContext streamContext = new StreamContext(contactId, transportId, getSecretKey(), getSecretKey(), 1); private final byte[] tag = getRandomBytes(TAG_LENGTH); - public KeyManagerImplTest() { - keyManager = new KeyManagerImpl(db, executor, pluginConfig, - transportKeyManagerFactory); - } + private final KeyManagerImpl keyManager = new KeyManagerImpl(db, executor, + pluginConfig, transportKeyManagerFactory); @Before public void testStartService() throws Exception { @@ -70,8 +69,8 @@ public class KeyManagerImplTest extends BrambleTestCase { true, false)); SimplexPluginFactory pluginFactory = context.mock(SimplexPluginFactory.class); - Collection<SimplexPluginFactory> factories = Collections - .singletonList(pluginFactory); + Collection<SimplexPluginFactory> factories = + singletonList(pluginFactory); int maxLatency = 1337; context.checking(new Expectations() {{ @@ -110,7 +109,22 @@ public class KeyManagerImplTest extends BrambleTestCase { }}); keyManager.addContact(txn, contactId, secretKey, timestamp, alice); - context.assertIsSatisfied(); + } + + @Test + public void testAddUnboundKeys() throws Exception { + SecretKey secretKey = getSecretKey(); + long timestamp = System.currentTimeMillis(); + boolean alice = new Random().nextBoolean(); + + context.checking(new Expectations() {{ + oneOf(transportKeyManager).addUnboundKeys(txn, secretKey, + timestamp, alice); + will(returnValue(keySetId)); + }}); + + assertEquals(singletonMap(transportId, keySetId), + keyManager.addUnboundKeys(txn, secretKey, timestamp, alice)); } @Test @@ -138,7 +152,6 @@ public class KeyManagerImplTest extends BrambleTestCase { assertEquals(streamContext, keyManager.getStreamContext(contactId, transportId)); - context.assertIsSatisfied(); } @Test @@ -161,7 +174,6 @@ public class KeyManagerImplTest extends BrambleTestCase { assertEquals(streamContext, keyManager.getStreamContext(transportId, tag)); - context.assertIsSatisfied(); } @Test @@ -175,8 +187,6 @@ public class KeyManagerImplTest extends BrambleTestCase { keyManager.eventOccurred(event); executor.runUntilIdle(); assertEquals(null, keyManager.getStreamContext(contactId, transportId)); - - context.assertIsSatisfied(); } @Test @@ -196,8 +206,5 @@ public class KeyManagerImplTest extends BrambleTestCase { keyManager.eventOccurred(event); assertEquals(streamContext, keyManager.getStreamContext(inactiveContactId, transportId)); - - context.assertIsSatisfied(); } - } diff --git a/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java index 4dc1f980240d9d91daa57860c5da1ea40d47b93d..21350ae6d7414409091bb5fafaa81b6b7fb87a1b 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java @@ -8,6 +8,8 @@ import org.briarproject.bramble.api.db.Transaction; import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.system.Clock; import org.briarproject.bramble.api.transport.IncomingKeys; +import org.briarproject.bramble.api.transport.KeySet; +import org.briarproject.bramble.api.transport.KeySetId; import org.briarproject.bramble.api.transport.OutgoingKeys; import org.briarproject.bramble.api.transport.StreamContext; import org.briarproject.bramble.api.transport.TransportKeys; @@ -22,14 +24,13 @@ import org.junit.Test; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; -import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; import java.util.Random; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.briarproject.bramble.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE; import static org.briarproject.bramble.api.transport.TransportConstants.PROTOCOL_VERSION; @@ -37,8 +38,10 @@ import static org.briarproject.bramble.api.transport.TransportConstants.REORDERI import static org.briarproject.bramble.api.transport.TransportConstants.TAG_LENGTH; import static org.briarproject.bramble.util.ByteUtils.MAX_32_BIT_UNSIGNED; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; public class TransportKeyManagerImplTest extends BrambleMockTestCase { @@ -55,6 +58,9 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { private final long rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE; private final ContactId contactId = new ContactId(123); private final ContactId contactId1 = new ContactId(234); + private final KeySetId keySetId = new KeySetId(345); + private final KeySetId keySetId1 = new KeySetId(456); + private final KeySetId keySetId2 = new KeySetId(567); private final SecretKey tagKey = TestUtils.getSecretKey(); private final SecretKey headerKey = TestUtils.getSecretKey(); private final SecretKey masterKey = TestUtils.getSecretKey(); @@ -62,12 +68,16 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { @Test public void testKeysAreRotatedAtStartup() throws Exception { - Map<ContactId, TransportKeys> loaded = new LinkedHashMap<>(); - TransportKeys shouldRotate = createTransportKeys(900, 0); - TransportKeys shouldNotRotate = createTransportKeys(1000, 0); - loaded.put(contactId, shouldRotate); - loaded.put(contactId1, shouldNotRotate); - TransportKeys rotated = createTransportKeys(1000, 0); + TransportKeys shouldRotate = createTransportKeys(900, 0, true); + TransportKeys shouldNotRotate = createTransportKeys(1000, 0, true); + TransportKeys shouldRotate1 = createTransportKeys(999, 0, false); + Collection<KeySet> loaded = asList( + new KeySet(keySetId, contactId, shouldRotate), + new KeySet(keySetId1, contactId1, shouldNotRotate), + new KeySet(keySetId2, null, shouldRotate1) + ); + TransportKeys rotated = createTransportKeys(1000, 0, true); + TransportKeys rotated1 = createTransportKeys(1000, 0, false); Transaction txn = new Transaction(null, false); context.checking(new Expectations() {{ @@ -82,6 +92,8 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { will(returnValue(rotated)); oneOf(transportCrypto).rotateTransportKeys(shouldNotRotate, 1000); will(returnValue(shouldNotRotate)); + oneOf(transportCrypto).rotateTransportKeys(shouldRotate1, 1000); + will(returnValue(rotated1)); // Encode the tags (3 sets per contact) for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { exactly(6).of(transportCrypto).encodeTag( @@ -90,8 +102,10 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { will(new EncodeTagAction()); } // Save the keys that were rotated - oneOf(db).updateTransportKeys(txn, - Collections.singletonMap(contactId, rotated)); + oneOf(db).updateTransportKeys(txn, asList( + new KeySet(keySetId, contactId, rotated), + new KeySet(keySetId2, null, rotated1)) + ); // Schedule key rotation at the start of the next rotation period oneOf(scheduler).schedule(with(any(Runnable.class)), with(rotationPeriodLength - 1), with(MILLISECONDS)); @@ -101,18 +115,19 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { db, transportCrypto, dbExecutor, scheduler, clock, transportId, maxLatency); transportKeyManager.start(txn); + assertTrue(transportKeyManager.canSendOutgoingStreams(contactId)); } @Test public void testKeysAreRotatedWhenAddingContact() throws Exception { boolean alice = random.nextBoolean(); - TransportKeys transportKeys = createTransportKeys(999, 0); - TransportKeys rotated = createTransportKeys(1000, 0); + TransportKeys transportKeys = createTransportKeys(999, 0, true); + TransportKeys rotated = createTransportKeys(1000, 0, true); Transaction txn = new Transaction(null, false); context.checking(new Expectations() {{ oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey, - 999, alice); + 999, alice, true); will(returnValue(transportKeys)); // Get the current time (1 ms after start of rotation period 1000) oneOf(clock).currentTimeMillis(); @@ -129,6 +144,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { } // Save the keys oneOf(db).addTransportKeys(txn, contactId, rotated); + will(returnValue(keySetId)); }}); TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( @@ -138,6 +154,39 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { long timestamp = rotationPeriodLength * 1000 - 1; transportKeyManager.addContact(txn, contactId, masterKey, timestamp, alice); + assertTrue(transportKeyManager.canSendOutgoingStreams(contactId)); + } + + @Test + public void testKeysAreRotatedWhenAddingUnboundKeys() throws Exception { + boolean alice = random.nextBoolean(); + TransportKeys transportKeys = createTransportKeys(999, 0, false); + TransportKeys rotated = createTransportKeys(1000, 0, false); + Transaction txn = new Transaction(null, false); + + context.checking(new Expectations() {{ + oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey, + 999, alice, false); + will(returnValue(transportKeys)); + // Get the current time (1 ms after start of rotation period 1000) + oneOf(clock).currentTimeMillis(); + will(returnValue(rotationPeriodLength * 1000 + 1)); + // Rotate the transport keys + oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000); + will(returnValue(rotated)); + // Save the keys + oneOf(db).addTransportKeys(txn, null, rotated); + will(returnValue(keySetId)); + }}); + + TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( + db, transportCrypto, dbExecutor, scheduler, clock, transportId, + maxLatency); + // The timestamp is 1 ms before the start of rotation period 1000 + long timestamp = rotationPeriodLength * 1000 - 1; + assertEquals(keySetId, transportKeyManager.addUnboundKeys(txn, + masterKey, timestamp, alice)); + assertFalse(transportKeyManager.canSendOutgoingStreams(contactId)); } @Test @@ -149,6 +198,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { db, transportCrypto, dbExecutor, scheduler, clock, transportId, maxLatency); assertNull(transportKeyManager.getStreamContext(txn, contactId)); + assertFalse(transportKeyManager.canSendOutgoingStreams(contactId)); } @Test @@ -157,29 +207,10 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { boolean alice = random.nextBoolean(); // The stream counter has been exhausted TransportKeys transportKeys = createTransportKeys(1000, - MAX_32_BIT_UNSIGNED + 1); + MAX_32_BIT_UNSIGNED + 1, true); Transaction txn = new Transaction(null, false); - context.checking(new Expectations() {{ - oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey, - 1000, alice); - will(returnValue(transportKeys)); - // Get the current time (the start of rotation period 1000) - oneOf(clock).currentTimeMillis(); - will(returnValue(rotationPeriodLength * 1000)); - // Encode the tags (3 sets) - for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { - exactly(3).of(transportCrypto).encodeTag( - with(any(byte[].class)), with(tagKey), - with(PROTOCOL_VERSION), with(i)); - will(new EncodeTagAction()); - } - // Rotate the transport keys (the keys are unaffected) - oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000); - will(returnValue(transportKeys)); - // Save the keys - oneOf(db).addTransportKeys(txn, contactId, transportKeys); - }}); + expectAddContactNoRotation(alice, transportKeys, txn); TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( db, transportCrypto, dbExecutor, scheduler, clock, transportId, @@ -188,6 +219,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { long timestamp = rotationPeriodLength * 1000; transportKeyManager.addContact(txn, contactId, masterKey, timestamp, alice); + assertFalse(transportKeyManager.canSendOutgoingStreams(contactId)); assertNull(transportKeyManager.getStreamContext(txn, contactId)); } @@ -196,30 +228,14 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { boolean alice = random.nextBoolean(); // The stream counter can be used one more time before being exhausted TransportKeys transportKeys = createTransportKeys(1000, - MAX_32_BIT_UNSIGNED); + MAX_32_BIT_UNSIGNED, true); Transaction txn = new Transaction(null, false); + expectAddContactNoRotation(alice, transportKeys, txn); + context.checking(new Expectations() {{ - oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey, - 1000, alice); - will(returnValue(transportKeys)); - // Get the current time (the start of rotation period 1000) - oneOf(clock).currentTimeMillis(); - will(returnValue(rotationPeriodLength * 1000)); - // Encode the tags (3 sets) - for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { - exactly(3).of(transportCrypto).encodeTag( - with(any(byte[].class)), with(tagKey), - with(PROTOCOL_VERSION), with(i)); - will(new EncodeTagAction()); - } - // Rotate the transport keys (the keys are unaffected) - oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000); - will(returnValue(transportKeys)); - // Save the keys - oneOf(db).addTransportKeys(txn, contactId, transportKeys); // Increment the stream counter - oneOf(db).incrementStreamCounter(txn, contactId, transportId, 1000); + oneOf(db).incrementStreamCounter(txn, transportId, keySetId); }}); TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( @@ -230,6 +246,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { transportKeyManager.addContact(txn, contactId, masterKey, timestamp, alice); // The first request should return a stream context + assertTrue(transportKeyManager.canSendOutgoingStreams(contactId)); StreamContext ctx = transportKeyManager.getStreamContext(txn, contactId); assertNotNull(ctx); @@ -239,6 +256,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { assertEquals(headerKey, ctx.getHeaderKey()); assertEquals(MAX_32_BIT_UNSIGNED, ctx.getStreamNumber()); // The second request should return null, the counter is exhausted + assertFalse(transportKeyManager.canSendOutgoingStreams(contactId)); assertNull(transportKeyManager.getStreamContext(txn, contactId)); } @@ -246,29 +264,10 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { public void testIncomingStreamContextIsNullIfTagIsNotFound() throws Exception { boolean alice = random.nextBoolean(); - TransportKeys transportKeys = createTransportKeys(1000, 0); + TransportKeys transportKeys = createTransportKeys(1000, 0, true); Transaction txn = new Transaction(null, false); - context.checking(new Expectations() {{ - oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey, - 1000, alice); - will(returnValue(transportKeys)); - // Get the current time (the start of rotation period 1000) - oneOf(clock).currentTimeMillis(); - will(returnValue(rotationPeriodLength * 1000)); - // Encode the tags (3 sets) - for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { - exactly(3).of(transportCrypto).encodeTag( - with(any(byte[].class)), with(tagKey), - with(PROTOCOL_VERSION), with(i)); - will(new EncodeTagAction()); - } - // Rotate the transport keys (the keys are unaffected) - oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000); - will(returnValue(transportKeys)); - // Save the keys - oneOf(db).addTransportKeys(txn, contactId, transportKeys); - }}); + expectAddContactNoRotation(alice, transportKeys, txn); TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( db, transportCrypto, dbExecutor, scheduler, clock, transportId, @@ -277,6 +276,8 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { long timestamp = rotationPeriodLength * 1000; transportKeyManager.addContact(txn, contactId, masterKey, timestamp, alice); + assertTrue(transportKeyManager.canSendOutgoingStreams(contactId)); + // The tag should not be recognised assertNull(transportKeyManager.getStreamContext(txn, new byte[TAG_LENGTH])); } @@ -284,14 +285,15 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { @Test public void testTagIsNotRecognisedTwice() throws Exception { boolean alice = random.nextBoolean(); - TransportKeys transportKeys = createTransportKeys(1000, 0); + TransportKeys transportKeys = createTransportKeys(1000, 0, true); + Transaction txn = new Transaction(null, false); + // Keep a copy of the tags List<byte[]> tags = new ArrayList<>(); - Transaction txn = new Transaction(null, false); context.checking(new Expectations() {{ oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey, - 1000, alice); + 1000, alice, true); will(returnValue(transportKeys)); // Get the current time (the start of rotation period 1000) oneOf(clock).currentTimeMillis(); @@ -308,13 +310,14 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { will(returnValue(transportKeys)); // Save the keys oneOf(db).addTransportKeys(txn, contactId, transportKeys); + will(returnValue(keySetId)); // Encode a new tag after sliding the window oneOf(transportCrypto).encodeTag(with(any(byte[].class)), with(tagKey), with(PROTOCOL_VERSION), with((long) REORDERING_WINDOW_SIZE)); will(new EncodeTagAction(tags)); // Save the reordering window (previous rotation period, base 1) - oneOf(db).setReorderingWindow(txn, contactId, transportId, 999, + oneOf(db).setReorderingWindow(txn, keySetId, transportId, 999, 1, new byte[REORDERING_WINDOW_SIZE / 8]); }}); @@ -325,6 +328,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { long timestamp = rotationPeriodLength * 1000; transportKeyManager.addContact(txn, contactId, masterKey, timestamp, alice); + assertTrue(transportKeyManager.canSendOutgoingStreams(contactId)); // Use the first tag (previous rotation period, stream number 0) assertEquals(REORDERING_WINDOW_SIZE * 3, tags.size()); byte[] tag = tags.get(0); @@ -344,10 +348,10 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { @Test public void testKeysAreRotatedToCurrentPeriod() throws Exception { - TransportKeys transportKeys = createTransportKeys(1000, 0); - Map<ContactId, TransportKeys> loaded = - Collections.singletonMap(contactId, transportKeys); - TransportKeys rotated = createTransportKeys(1001, 0); + TransportKeys transportKeys = createTransportKeys(1000, 0, true); + Collection<KeySet> loaded = + singletonList(new KeySet(keySetId, contactId, transportKeys)); + TransportKeys rotated = createTransportKeys(1001, 0, true); Transaction txn = new Transaction(null, false); Transaction txn1 = new Transaction(null, false); @@ -393,7 +397,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { } // Save the keys that were rotated oneOf(db).updateTransportKeys(txn1, - Collections.singletonMap(contactId, rotated)); + singletonList(new KeySet(keySetId, contactId, rotated))); // Schedule key rotation at the start of the next rotation period oneOf(scheduler).schedule(with(any(Runnable.class)), with(rotationPeriodLength), with(MILLISECONDS)); @@ -406,10 +410,197 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { db, transportCrypto, dbExecutor, scheduler, clock, transportId, maxLatency); transportKeyManager.start(txn); + assertTrue(transportKeyManager.canSendOutgoingStreams(contactId)); + } + + @Test + public void testBindingAndActivatingKeys() throws Exception { + boolean alice = random.nextBoolean(); + TransportKeys transportKeys = createTransportKeys(1000, 0, false); + Transaction txn = new Transaction(null, false); + + expectAddUnboundKeysNoRotation(alice, transportKeys, txn); + + context.checking(new Expectations() {{ + // When the keys are bound, encode the tags (3 sets) + for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { + exactly(3).of(transportCrypto).encodeTag( + with(any(byte[].class)), with(tagKey), + with(PROTOCOL_VERSION), with(i)); + will(new EncodeTagAction()); + } + // Save the key binding + oneOf(db).bindTransportKeys(txn, contactId, transportId, keySetId); + // Activate the keys + oneOf(db).setTransportKeysActive(txn, transportId, keySetId); + // Increment the stream counter + oneOf(db).incrementStreamCounter(txn, transportId, keySetId); + }}); + + TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( + db, transportCrypto, dbExecutor, scheduler, clock, transportId, + maxLatency); + // The timestamp is at the start of rotation period 1000 + long timestamp = rotationPeriodLength * 1000; + assertEquals(keySetId, transportKeyManager.addUnboundKeys(txn, + masterKey, timestamp, alice)); + // The keys are unbound so no stream context should be returned + assertFalse(transportKeyManager.canSendOutgoingStreams(contactId)); + assertNull(transportKeyManager.getStreamContext(txn, contactId)); + transportKeyManager.bindKeys(txn, contactId, keySetId); + // The keys are inactive so no stream context should be returned + assertFalse(transportKeyManager.canSendOutgoingStreams(contactId)); + assertNull(transportKeyManager.getStreamContext(txn, contactId)); + transportKeyManager.activateKeys(txn, keySetId); + // The keys are active so a stream context should be returned + assertTrue(transportKeyManager.canSendOutgoingStreams(contactId)); + StreamContext ctx = transportKeyManager.getStreamContext(txn, + contactId); + assertNotNull(ctx); + assertEquals(contactId, ctx.getContactId()); + assertEquals(transportId, ctx.getTransportId()); + assertEquals(tagKey, ctx.getTagKey()); + assertEquals(headerKey, ctx.getHeaderKey()); + assertEquals(0L, ctx.getStreamNumber()); + } + + @Test + public void testRecognisingTagActivatesOutgoingKeys() throws Exception { + boolean alice = random.nextBoolean(); + TransportKeys transportKeys = createTransportKeys(1000, 0, false); + Transaction txn = new Transaction(null, false); + + // Keep a copy of the tags + List<byte[]> tags = new ArrayList<>(); + + expectAddUnboundKeysNoRotation(alice, transportKeys, txn); + + context.checking(new Expectations() {{ + // When the keys are bound, encode the tags (3 sets) + for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { + exactly(3).of(transportCrypto).encodeTag( + with(any(byte[].class)), with(tagKey), + with(PROTOCOL_VERSION), with(i)); + will(new EncodeTagAction(tags)); + } + // Save the key binding + oneOf(db).bindTransportKeys(txn, contactId, transportId, keySetId); + // Encode a new tag after sliding the window + oneOf(transportCrypto).encodeTag(with(any(byte[].class)), + with(tagKey), with(PROTOCOL_VERSION), + with((long) REORDERING_WINDOW_SIZE)); + will(new EncodeTagAction(tags)); + // Save the reordering window (previous rotation period, base 1) + oneOf(db).setReorderingWindow(txn, keySetId, transportId, 999, + 1, new byte[REORDERING_WINDOW_SIZE / 8]); + // Activate the keys + oneOf(db).setTransportKeysActive(txn, transportId, keySetId); + // Increment the stream counter + oneOf(db).incrementStreamCounter(txn, transportId, keySetId); + }}); + + TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( + db, transportCrypto, dbExecutor, scheduler, clock, transportId, + maxLatency); + // The timestamp is at the start of rotation period 1000 + long timestamp = rotationPeriodLength * 1000; + assertEquals(keySetId, transportKeyManager.addUnboundKeys(txn, + masterKey, timestamp, alice)); + transportKeyManager.bindKeys(txn, contactId, keySetId); + // The keys are inactive so no stream context should be returned + assertFalse(transportKeyManager.canSendOutgoingStreams(contactId)); + assertNull(transportKeyManager.getStreamContext(txn, contactId)); + // Recognising an incoming tag should activate the outgoing keys + assertEquals(REORDERING_WINDOW_SIZE * 3, tags.size()); + byte[] tag = tags.get(0); + StreamContext ctx = transportKeyManager.getStreamContext(txn, tag); + assertNotNull(ctx); + assertEquals(contactId, ctx.getContactId()); + assertEquals(transportId, ctx.getTransportId()); + assertEquals(tagKey, ctx.getTagKey()); + assertEquals(headerKey, ctx.getHeaderKey()); + assertEquals(0L, ctx.getStreamNumber()); + // The keys are active so a stream context should be returned + assertTrue(transportKeyManager.canSendOutgoingStreams(contactId)); + ctx = transportKeyManager.getStreamContext(txn, contactId); + assertNotNull(ctx); + assertEquals(contactId, ctx.getContactId()); + assertEquals(transportId, ctx.getTransportId()); + assertEquals(tagKey, ctx.getTagKey()); + assertEquals(headerKey, ctx.getHeaderKey()); + assertEquals(0L, ctx.getStreamNumber()); + } + + @Test + public void testRemovingUnboundKeys() throws Exception { + boolean alice = random.nextBoolean(); + TransportKeys transportKeys = createTransportKeys(1000, 0, false); + Transaction txn = new Transaction(null, false); + + expectAddUnboundKeysNoRotation(alice, transportKeys, txn); + + context.checking(new Expectations() {{ + // Remove the unbound keys + oneOf(db).removeTransportKeys(txn, transportId, keySetId); + }}); + + TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( + db, transportCrypto, dbExecutor, scheduler, clock, transportId, + maxLatency); + // The timestamp is at the start of rotation period 1000 + long timestamp = rotationPeriodLength * 1000; + assertEquals(keySetId, transportKeyManager.addUnboundKeys(txn, + masterKey, timestamp, alice)); + assertFalse(transportKeyManager.canSendOutgoingStreams(contactId)); + transportKeyManager.removeKeys(txn, keySetId); + assertFalse(transportKeyManager.canSendOutgoingStreams(contactId)); + } + + private void expectAddContactNoRotation(boolean alice, + TransportKeys transportKeys, Transaction txn) throws Exception { + context.checking(new Expectations() {{ + oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey, + 1000, alice, true); + will(returnValue(transportKeys)); + // Get the current time (the start of rotation period 1000) + oneOf(clock).currentTimeMillis(); + will(returnValue(rotationPeriodLength * 1000)); + // Encode the tags (3 sets) + for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { + exactly(3).of(transportCrypto).encodeTag( + with(any(byte[].class)), with(tagKey), + with(PROTOCOL_VERSION), with(i)); + will(new EncodeTagAction()); + } + // Rotate the transport keys (the keys are unaffected) + oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000); + will(returnValue(transportKeys)); + // Save the keys + oneOf(db).addTransportKeys(txn, contactId, transportKeys); + will(returnValue(keySetId)); + }}); + } + + private void expectAddUnboundKeysNoRotation(boolean alice, + TransportKeys transportKeys, Transaction txn) throws Exception { + context.checking(new Expectations() {{ + oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey, + 1000, alice, false); + will(returnValue(transportKeys)); + // Get the current time (the start of rotation period 1000) + oneOf(clock).currentTimeMillis(); + will(returnValue(rotationPeriodLength * 1000)); + // Rotate the transport keys (the keys are unaffected) + oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000); + will(returnValue(transportKeys)); + // Save the unbound keys + oneOf(db).addTransportKeys(txn, null, transportKeys); + will(returnValue(keySetId)); + }}); } private TransportKeys createTransportKeys(long rotationPeriod, - long streamCounter) { + long streamCounter, boolean active) { IncomingKeys inPrev = new IncomingKeys(tagKey, headerKey, rotationPeriod - 1); IncomingKeys inCurr = new IncomingKeys(tagKey, headerKey, @@ -417,7 +608,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { IncomingKeys inNext = new IncomingKeys(tagKey, headerKey, rotationPeriod + 1); OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey, - rotationPeriod, streamCounter); + rotationPeriod, streamCounter, active); return new TransportKeys(transportId, inPrev, inCurr, inNext, outCurr); }