diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactId.java b/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactId.java
index 9e11e555783018e242eef89a7141e8565f68ea12..21293599febac3854411de04de5ace41017165e4 100644
--- a/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactId.java
+++ b/bramble-api/src/main/java/org/briarproject/bramble/api/contact/ContactId.java
@@ -6,7 +6,7 @@ import javax.annotation.concurrent.Immutable;
 
 /**
  * Type-safe wrapper for an integer that uniquely identifies a contact within
- * the scope of a single node.
+ * the scope of the local device.
  */
 @Immutable
 @NotNullByDefault
diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/db/DatabaseComponent.java b/bramble-api/src/main/java/org/briarproject/bramble/api/db/DatabaseComponent.java
index 08fbe45405b0b5c3a1011e4780156a7d157c6219..45f721f22df00a5160352619a0801955a037fbd0 100644
--- a/bramble-api/src/main/java/org/briarproject/bramble/api/db/DatabaseComponent.java
+++ b/bramble-api/src/main/java/org/briarproject/bramble/api/db/DatabaseComponent.java
@@ -18,6 +18,8 @@ import org.briarproject.bramble.api.sync.MessageId;
 import org.briarproject.bramble.api.sync.MessageStatus;
 import org.briarproject.bramble.api.sync.Offer;
 import org.briarproject.bramble.api.sync.Request;
+import org.briarproject.bramble.api.transport.KeySet;
+import org.briarproject.bramble.api.transport.KeySetId;
 import org.briarproject.bramble.api.transport.TransportKeys;
 
 import java.util.Collection;
@@ -102,10 +104,11 @@ public interface DatabaseComponent {
 			throws DbException;
 
 	/**
-	 * Stores transport keys for a newly added contact.
+	 * Stores the given transport keys, optionally binding them to the given
+	 * contact, and returns a key set ID.
 	 */
-	void addTransportKeys(Transaction txn, ContactId c, TransportKeys k)
-			throws DbException;
+	KeySetId addTransportKeys(Transaction txn, @Nullable ContactId c,
+			TransportKeys k) throws DbException;
 
 	/**
 	 * Returns true if the database contains the given contact for the given
@@ -394,8 +397,8 @@ public interface DatabaseComponent {
 	 * <p/>
 	 * Read-only.
 	 */
-	Map<ContactId, TransportKeys> getTransportKeys(Transaction txn,
-			TransportId t) throws DbException;
+	Collection<KeySet> getTransportKeys(Transaction txn, TransportId t)
+			throws DbException;
 
 	/**
 	 * Increments the outgoing stream counter for the given contact and
@@ -507,15 +510,15 @@ public interface DatabaseComponent {
 			Collection<MessageId> dependencies) throws DbException;
 
 	/**
-	 * Sets the reordering window for the given contact and transport in the
+	 * Sets the reordering window for the given key set and transport in the
 	 * given rotation period.
 	 */
-	void setReorderingWindow(Transaction txn, ContactId c, TransportId t,
+	void setReorderingWindow(Transaction txn, KeySetId k, TransportId t,
 			long rotationPeriod, long base, byte[] bitmap) throws DbException;
 
 	/**
 	 * Stores the given transport keys, deleting any keys they have replaced.
 	 */
-	void updateTransportKeys(Transaction txn,
-			Map<ContactId, TransportKeys> keys) throws DbException;
+	void updateTransportKeys(Transaction txn, Collection<KeySet> keys)
+			throws DbException;
 }
diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySet.java b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySet.java
new file mode 100644
index 0000000000000000000000000000000000000000..9cc8f63c294a6cd0b4674944e51307f72058ddf6
--- /dev/null
+++ b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySet.java
@@ -0,0 +1,51 @@
+package org.briarproject.bramble.api.transport;
+
+import org.briarproject.bramble.api.contact.ContactId;
+import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
+
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.Immutable;
+
+/**
+ * A set of transport keys for communicating with a contact. If the keys have
+ * not yet been bound to a contact, {@link #getContactId()}} returns null.
+ */
+@Immutable
+@NotNullByDefault
+public class KeySet {
+
+	private final KeySetId keySetId;
+	@Nullable
+	private final ContactId contactId;
+	private final TransportKeys transportKeys;
+
+	public KeySet(KeySetId keySetId, @Nullable ContactId contactId,
+			TransportKeys transportKeys) {
+		this.keySetId = keySetId;
+		this.contactId = contactId;
+		this.transportKeys = transportKeys;
+	}
+
+	public KeySetId getKeySetId() {
+		return keySetId;
+	}
+
+	@Nullable
+	public ContactId getContactId() {
+		return contactId;
+	}
+
+	public TransportKeys getTransportKeys() {
+		return transportKeys;
+	}
+
+	@Override
+	public int hashCode() {
+		return keySetId.hashCode();
+	}
+
+	@Override
+	public boolean equals(Object o) {
+		return o instanceof KeySet && keySetId.equals(((KeySet) o).keySetId);
+	}
+}
diff --git a/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySetId.java b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySetId.java
new file mode 100644
index 0000000000000000000000000000000000000000..1f872e72a34d7ffb480423d11bbe1aa0db1bdf82
--- /dev/null
+++ b/bramble-api/src/main/java/org/briarproject/bramble/api/transport/KeySetId.java
@@ -0,0 +1,36 @@
+package org.briarproject.bramble.api.transport;
+
+import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
+
+import javax.annotation.concurrent.Immutable;
+
+/**
+ * Type-safe wrapper for an integer that uniquely identifies a set of transport
+ * keys within the scope of the local device.
+ * <p/>
+ * Key sets created on a given device must have increasing identifiers.
+ */
+@Immutable
+@NotNullByDefault
+public class KeySetId {
+
+	private final int id;
+
+	public KeySetId(int id) {
+		this.id = id;
+	}
+
+	public int getInt() {
+		return id;
+	}
+
+	@Override
+	public int hashCode() {
+		return id;
+	}
+
+	@Override
+	public boolean equals(Object o) {
+		return o instanceof KeySetId && id == ((KeySetId) o).id;
+	}
+}
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java b/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java
index bc3166d5622e605a1fdddcbf0858bc438b2216ce..e8597eaba7deca87148ede882f806b7369d35f70 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java
@@ -21,6 +21,8 @@ import org.briarproject.bramble.api.sync.Message;
 import org.briarproject.bramble.api.sync.MessageId;
 import org.briarproject.bramble.api.sync.MessageStatus;
 import org.briarproject.bramble.api.sync.ValidationManager.State;
+import org.briarproject.bramble.api.transport.KeySet;
+import org.briarproject.bramble.api.transport.KeySetId;
 import org.briarproject.bramble.api.transport.TransportKeys;
 
 import java.util.Collection;
@@ -123,9 +125,10 @@ interface Database<T> {
 			throws DbException;
 
 	/**
-	 * Stores transport keys for a newly added contact.
+	 * Stores the given transport keys, optionally binding them to the given
+	 * contact, and returns a key set ID.
 	 */
-	void addTransportKeys(T txn, ContactId c, TransportKeys k)
+	KeySetId addTransportKeys(T txn, @Nullable ContactId c, TransportKeys k)
 			throws DbException;
 
 	/**
@@ -486,7 +489,7 @@ interface Database<T> {
 	 * <p/>
 	 * Read-only.
 	 */
-	Map<ContactId, TransportKeys> getTransportKeys(T txn, TransportId t)
+	Collection<KeySet> getTransportKeys(T txn, TransportId t)
 			throws DbException;
 
 	/**
@@ -619,10 +622,10 @@ interface Database<T> {
 	void setMessageState(T txn, MessageId m, State state) throws DbException;
 
 	/**
-	 * Sets the reordering window for the given contact and transport in the
+	 * Sets the reordering window for the given key set and transport in the
 	 * given rotation period.
 	 */
-	void setReorderingWindow(T txn, ContactId c, TransportId t,
+	void setReorderingWindow(T txn, KeySetId k, TransportId t,
 			long rotationPeriod, long base, byte[] bitmap) throws DbException;
 
 	/**
@@ -636,6 +639,5 @@ interface Database<T> {
 	/**
 	 * Stores the given transport keys, deleting any keys they have replaced.
 	 */
-	void updateTransportKeys(T txn, Map<ContactId, TransportKeys> keys)
-			throws DbException;
+	void updateTransportKeys(T txn, Collection<KeySet> keys) throws DbException;
 }
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java
index 013233f395ceac5ff0f65cf5a0044473238858c4..f90f9e3ff18631275a73cfd890bd6f51dd4ad9c6 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java
@@ -51,15 +51,15 @@ import org.briarproject.bramble.api.sync.event.MessageToAckEvent;
 import org.briarproject.bramble.api.sync.event.MessageToRequestEvent;
 import org.briarproject.bramble.api.sync.event.MessagesAckedEvent;
 import org.briarproject.bramble.api.sync.event.MessagesSentEvent;
+import org.briarproject.bramble.api.transport.KeySet;
+import org.briarproject.bramble.api.transport.KeySetId;
 import org.briarproject.bramble.api.transport.TransportKeys;
 
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Map.Entry;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
 import java.util.logging.Logger;
@@ -234,15 +234,15 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
 	}
 
 	@Override
-	public void addTransportKeys(Transaction transaction, ContactId c,
-			TransportKeys k) throws DbException {
+	public KeySetId addTransportKeys(Transaction transaction,
+			@Nullable ContactId c, TransportKeys k) throws DbException {
 		if (transaction.isReadOnly()) throw new IllegalArgumentException();
 		T txn = unbox(transaction);
-		if (!db.containsContact(txn, c))
+		if (c != null && !db.containsContact(txn, c))
 			throw new NoSuchContactException();
 		if (!db.containsTransport(txn, k.getTransportId()))
 			throw new NoSuchTransportException();
-		db.addTransportKeys(txn, c, k);
+		return db.addTransportKeys(txn, c, k);
 	}
 
 	@Override
@@ -586,8 +586,8 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
 	}
 
 	@Override
-	public Map<ContactId, TransportKeys> getTransportKeys(
-			Transaction transaction, TransportId t) throws DbException {
+	public Collection<KeySet> getTransportKeys(Transaction transaction,
+			TransportId t) throws DbException {
 		T txn = unbox(transaction);
 		if (!db.containsTransport(txn, t))
 			throw new NoSuchTransportException();
@@ -858,31 +858,25 @@ class DatabaseComponentImpl<T> implements DatabaseComponent {
 	}
 
 	@Override
-	public void setReorderingWindow(Transaction transaction, ContactId c,
+	public void setReorderingWindow(Transaction transaction, KeySetId k,
 			TransportId t, long rotationPeriod, long base, byte[] bitmap)
 			throws DbException {
 		if (transaction.isReadOnly()) throw new IllegalArgumentException();
 		T txn = unbox(transaction);
-		if (!db.containsContact(txn, c))
-			throw new NoSuchContactException();
 		if (!db.containsTransport(txn, t))
 			throw new NoSuchTransportException();
-		db.setReorderingWindow(txn, c, t, rotationPeriod, base, bitmap);
+		db.setReorderingWindow(txn, k, t, rotationPeriod, base, bitmap);
 	}
 
 	@Override
 	public void updateTransportKeys(Transaction transaction,
-			Map<ContactId, TransportKeys> keys) throws DbException {
+			Collection<KeySet> keys) throws DbException {
 		if (transaction.isReadOnly()) throw new IllegalArgumentException();
 		T txn = unbox(transaction);
-		Map<ContactId, TransportKeys> filtered = new HashMap<>();
-		for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
-			ContactId c = e.getKey();
-			TransportKeys k = e.getValue();
-			if (db.containsContact(txn, c)
-					&& db.containsTransport(txn, k.getTransportId())) {
-				filtered.put(c, k);
-			}
+		Collection<KeySet> filtered = new ArrayList<>();
+		for (KeySet ks : keys) {
+			TransportId t = ks.getTransportKeys().getTransportId();
+			if (db.containsTransport(txn, t)) filtered.add(ks);
 		}
 		db.updateTransportKeys(txn, filtered);
 	}
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java b/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java
index 51fa9ae309c39b93421e37673745ce143e8e59f5..327203d63005514b6cb66ef389641c8f49859a99 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java
@@ -25,6 +25,8 @@ import org.briarproject.bramble.api.sync.MessageStatus;
 import org.briarproject.bramble.api.sync.ValidationManager.State;
 import org.briarproject.bramble.api.system.Clock;
 import org.briarproject.bramble.api.transport.IncomingKeys;
+import org.briarproject.bramble.api.transport.KeySet;
+import org.briarproject.bramble.api.transport.KeySetId;
 import org.briarproject.bramble.api.transport.OutgoingKeys;
 import org.briarproject.bramble.api.transport.TransportKeys;
 
@@ -223,37 +225,43 @@ abstract class JdbcDatabase implements Database<Connection> {
 					+ " maxLatency INT NOT NULL,"
 					+ " PRIMARY KEY (transportId))";
 
-	private static final String CREATE_INCOMING_KEYS =
-			"CREATE TABLE incomingKeys"
-					+ " (contactId INT NOT NULL,"
-					+ " transportId _STRING NOT NULL,"
+	private static final String CREATE_OUTGOING_KEYS =
+			"CREATE TABLE outgoingKeys"
+					+ " (transportId _STRING NOT NULL,"
+					+ " keySetId _COUNTER,"
 					+ " rotationPeriod BIGINT NOT NULL,"
+					+ " contactId INT," // Null if keys are not bound
 					+ " tagKey _SECRET NOT NULL,"
 					+ " headerKey _SECRET NOT NULL,"
-					+ " base BIGINT NOT NULL,"
-					+ " bitmap _BINARY NOT NULL,"
-					+ " PRIMARY KEY (contactId, transportId, rotationPeriod),"
-					+ " FOREIGN KEY (contactId)"
-					+ " REFERENCES contacts (contactId)"
-					+ " ON DELETE CASCADE,"
+					+ " stream BIGINT NOT NULL,"
+					+ " PRIMARY KEY (transportId, keySetId),"
 					+ " FOREIGN KEY (transportId)"
 					+ " REFERENCES transports (transportId)"
+					+ " ON DELETE CASCADE,"
+					+ " UNIQUE (keySetId),"
+					+ " FOREIGN KEY (contactId)"
+					+ " REFERENCES contacts (contactId)"
 					+ " ON DELETE CASCADE)";
 
-	private static final String CREATE_OUTGOING_KEYS =
-			"CREATE TABLE outgoingKeys"
-					+ " (contactId INT NOT NULL,"
-					+ " transportId _STRING NOT NULL,"
+	private static final String CREATE_INCOMING_KEYS =
+			"CREATE TABLE incomingKeys"
+					+ " (transportId _STRING NOT NULL,"
+					+ " keySetId INT NOT NULL,"
 					+ " rotationPeriod BIGINT NOT NULL,"
+					+ " contactId INT," // Null if keys are not bound
 					+ " tagKey _SECRET NOT NULL,"
 					+ " headerKey _SECRET NOT NULL,"
-					+ " stream BIGINT NOT NULL,"
-					+ " PRIMARY KEY (contactId, transportId),"
-					+ " FOREIGN KEY (contactId)"
-					+ " REFERENCES contacts (contactId)"
-					+ " ON DELETE CASCADE,"
+					+ " base BIGINT NOT NULL,"
+					+ " bitmap _BINARY NOT NULL,"
+					+ " PRIMARY KEY (transportId, keySetId, rotationPeriod),"
 					+ " FOREIGN KEY (transportId)"
 					+ " REFERENCES transports (transportId)"
+					+ " ON DELETE CASCADE,"
+					+ " FOREIGN KEY (keySetId)"
+					+ " REFERENCES outgoingKeys (keySetId)"
+					+ " ON DELETE CASCADE,"
+					+ " FOREIGN KEY (contactId)"
+					+ " REFERENCES contacts (contactId)"
 					+ " ON DELETE CASCADE)";
 
 	private static final String INDEX_CONTACTS_BY_AUTHOR_ID =
@@ -415,8 +423,8 @@ abstract class JdbcDatabase implements Database<Connection> {
 			s.executeUpdate(insertTypeNames(CREATE_OFFERS));
 			s.executeUpdate(insertTypeNames(CREATE_STATUSES));
 			s.executeUpdate(insertTypeNames(CREATE_TRANSPORTS));
-			s.executeUpdate(insertTypeNames(CREATE_INCOMING_KEYS));
 			s.executeUpdate(insertTypeNames(CREATE_OUTGOING_KEYS));
+			s.executeUpdate(insertTypeNames(CREATE_INCOMING_KEYS));
 			s.close();
 		} catch (SQLException e) {
 			tryToClose(s);
@@ -865,62 +873,78 @@ abstract class JdbcDatabase implements Database<Connection> {
 	}
 
 	@Override
-	public void addTransportKeys(Connection txn, ContactId c, TransportKeys k)
-			throws DbException {
+	public KeySetId addTransportKeys(Connection txn, @Nullable ContactId c,
+			TransportKeys k) throws DbException {
 		PreparedStatement ps = null;
+		ResultSet rs = null;
 		try {
+			// Store the outgoing keys
+			String sql = "INSERT INTO outgoingKeys (contactId, transportId,"
+					+ " rotationPeriod, tagKey, headerKey, stream)"
+					+ " VALUES (?, ?, ?, ?, ?, ?)";
+			ps = txn.prepareStatement(sql);
+			if (c == null) ps.setNull(1, INTEGER);
+			else ps.setInt(1, c.getInt());
+			ps.setString(2, k.getTransportId().getString());
+			OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
+			ps.setLong(3, outCurr.getRotationPeriod());
+			ps.setBytes(4, outCurr.getTagKey().getBytes());
+			ps.setBytes(5, outCurr.getHeaderKey().getBytes());
+			ps.setLong(6, outCurr.getStreamCounter());
+			int affected = ps.executeUpdate();
+			if (affected != 1) throw new DbStateException();
+			ps.close();
+			// Get the new (highest) key set ID
+			sql = "SELECT keySetId FROM outgoingKeys"
+					+ " ORDER BY keySetId DESC LIMIT 1";
+			ps = txn.prepareStatement(sql);
+			rs = ps.executeQuery();
+			if (!rs.next()) throw new DbStateException();
+			KeySetId keySetId = new KeySetId(rs.getInt(1));
+			if (rs.next()) throw new DbStateException();
+			rs.close();
+			ps.close();
 			// Store the incoming keys
-			String sql = "INSERT INTO incomingKeys (contactId, transportId,"
+			sql = "INSERT INTO incomingKeys (keySetId, contactId, transportId,"
 					+ " rotationPeriod, tagKey, headerKey, base, bitmap)"
-					+ " VALUES (?, ?, ?, ?, ?, ?, ?)";
+					+ " VALUES (?, ?, ?, ?, ?, ?, ?, ?)";
 			ps = txn.prepareStatement(sql);
-			ps.setInt(1, c.getInt());
-			ps.setString(2, k.getTransportId().getString());
+			ps.setInt(1, keySetId.getInt());
+			if (c == null) ps.setNull(2, INTEGER);
+			else ps.setInt(2, c.getInt());
+			ps.setString(3, k.getTransportId().getString());
 			// Previous rotation period
 			IncomingKeys inPrev = k.getPreviousIncomingKeys();
-			ps.setLong(3, inPrev.getRotationPeriod());
-			ps.setBytes(4, inPrev.getTagKey().getBytes());
-			ps.setBytes(5, inPrev.getHeaderKey().getBytes());
-			ps.setLong(6, inPrev.getWindowBase());
-			ps.setBytes(7, inPrev.getWindowBitmap());
+			ps.setLong(4, inPrev.getRotationPeriod());
+			ps.setBytes(5, inPrev.getTagKey().getBytes());
+			ps.setBytes(6, inPrev.getHeaderKey().getBytes());
+			ps.setLong(7, inPrev.getWindowBase());
+			ps.setBytes(8, inPrev.getWindowBitmap());
 			ps.addBatch();
 			// Current rotation period
 			IncomingKeys inCurr = k.getCurrentIncomingKeys();
-			ps.setLong(3, inCurr.getRotationPeriod());
-			ps.setBytes(4, inCurr.getTagKey().getBytes());
-			ps.setBytes(5, inCurr.getHeaderKey().getBytes());
-			ps.setLong(6, inCurr.getWindowBase());
-			ps.setBytes(7, inCurr.getWindowBitmap());
+			ps.setLong(4, inCurr.getRotationPeriod());
+			ps.setBytes(5, inCurr.getTagKey().getBytes());
+			ps.setBytes(6, inCurr.getHeaderKey().getBytes());
+			ps.setLong(7, inCurr.getWindowBase());
+			ps.setBytes(8, inCurr.getWindowBitmap());
 			ps.addBatch();
 			// Next rotation period
 			IncomingKeys inNext = k.getNextIncomingKeys();
-			ps.setLong(3, inNext.getRotationPeriod());
-			ps.setBytes(4, inNext.getTagKey().getBytes());
-			ps.setBytes(5, inNext.getHeaderKey().getBytes());
-			ps.setLong(6, inNext.getWindowBase());
-			ps.setBytes(7, inNext.getWindowBitmap());
+			ps.setLong(4, inNext.getRotationPeriod());
+			ps.setBytes(5, inNext.getTagKey().getBytes());
+			ps.setBytes(6, inNext.getHeaderKey().getBytes());
+			ps.setLong(7, inNext.getWindowBase());
+			ps.setBytes(8, inNext.getWindowBitmap());
 			ps.addBatch();
 			int[] batchAffected = ps.executeBatch();
 			if (batchAffected.length != 3) throw new DbStateException();
 			for (int rows : batchAffected)
 				if (rows != 1) throw new DbStateException();
 			ps.close();
-			// Store the outgoing keys
-			sql = "INSERT INTO outgoingKeys (contactId, transportId,"
-					+ " rotationPeriod, tagKey, headerKey, stream)"
-					+ " VALUES (?, ?, ?, ?, ?, ?)";
-			ps = txn.prepareStatement(sql);
-			ps.setInt(1, c.getInt());
-			ps.setString(2, k.getTransportId().getString());
-			OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
-			ps.setLong(3, outCurr.getRotationPeriod());
-			ps.setBytes(4, outCurr.getTagKey().getBytes());
-			ps.setBytes(5, outCurr.getHeaderKey().getBytes());
-			ps.setLong(6, outCurr.getStreamCounter());
-			int affected = ps.executeUpdate();
-			if (affected != 1) throw new DbStateException();
-			ps.close();
+			return keySetId;
 		} catch (SQLException e) {
+			tryToClose(rs);
 			tryToClose(ps);
 			throw new DbException(e);
 		}
@@ -2078,8 +2102,8 @@ abstract class JdbcDatabase implements Database<Connection> {
 	}
 
 	@Override
-	public Map<ContactId, TransportKeys> getTransportKeys(Connection txn,
-			TransportId t) throws DbException {
+	public Collection<KeySet> getTransportKeys(Connection txn, TransportId t)
+			throws DbException {
 		PreparedStatement ps = null;
 		ResultSet rs = null;
 		try {
@@ -2088,7 +2112,7 @@ abstract class JdbcDatabase implements Database<Connection> {
 					+ " base, bitmap"
 					+ " FROM incomingKeys"
 					+ " WHERE transportId = ?"
-					+ " ORDER BY contactId, rotationPeriod";
+					+ " ORDER BY keySetId, rotationPeriod";
 			ps = txn.prepareStatement(sql);
 			ps.setString(1, t.getString());
 			rs = ps.executeQuery();
@@ -2105,29 +2129,33 @@ abstract class JdbcDatabase implements Database<Connection> {
 			rs.close();
 			ps.close();
 			// Retrieve the outgoing keys in the same order
-			sql = "SELECT contactId, rotationPeriod, tagKey, headerKey, stream"
+			sql = "SELECT keySetId, contactId, rotationPeriod,"
+					+ " tagKey, headerKey, stream"
 					+ " FROM outgoingKeys"
 					+ " WHERE transportId = ?"
-					+ " ORDER BY contactId, rotationPeriod";
+					+ " ORDER BY keySetId";
 			ps = txn.prepareStatement(sql);
 			ps.setString(1, t.getString());
 			rs = ps.executeQuery();
-			Map<ContactId, TransportKeys> keys = new HashMap<>();
+			Collection<KeySet> keys = new ArrayList<>();
 			for (int i = 0; rs.next(); i++) {
 				// There should be three times as many incoming keys
 				if (inKeys.size() < (i + 1) * 3) throw new DbStateException();
-				ContactId contactId = new ContactId(rs.getInt(1));
-				long rotationPeriod = rs.getLong(2);
-				SecretKey tagKey = new SecretKey(rs.getBytes(3));
-				SecretKey headerKey = new SecretKey(rs.getBytes(4));
-				long streamCounter = rs.getLong(5);
+				KeySetId keySetId = new KeySetId(rs.getInt(1));
+				ContactId contactId = new ContactId(rs.getInt(2));
+				if (rs.wasNull()) contactId = null;
+				long rotationPeriod = rs.getLong(3);
+				SecretKey tagKey = new SecretKey(rs.getBytes(4));
+				SecretKey headerKey = new SecretKey(rs.getBytes(5));
+				long streamCounter = rs.getLong(6);
 				OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey,
 						rotationPeriod, streamCounter);
 				IncomingKeys inPrev = inKeys.get(i * 3);
 				IncomingKeys inCurr = inKeys.get(i * 3 + 1);
 				IncomingKeys inNext = inKeys.get(i * 3 + 2);
-				keys.put(contactId, new TransportKeys(t, inPrev, inCurr,
-						inNext, outCurr));
+				TransportKeys transportKeys = new TransportKeys(t, inPrev,
+						inCurr, inNext, outCurr);
+				keys.add(new KeySet(keySetId, contactId, transportKeys));
 			}
 			rs.close();
 			ps.close();
@@ -2791,18 +2819,18 @@ abstract class JdbcDatabase implements Database<Connection> {
 	}
 
 	@Override
-	public void setReorderingWindow(Connection txn, ContactId c, TransportId t,
+	public void setReorderingWindow(Connection txn, KeySetId k, TransportId t,
 			long rotationPeriod, long base, byte[] bitmap) throws DbException {
 		PreparedStatement ps = null;
 		try {
 			String sql = "UPDATE incomingKeys SET base = ?, bitmap = ?"
-					+ " WHERE contactId = ? AND transportId = ?"
+					+ " WHERE transportId = ? AND keySetId = ?"
 					+ " AND rotationPeriod = ?";
 			ps = txn.prepareStatement(sql);
 			ps.setLong(1, base);
 			ps.setBytes(2, bitmap);
-			ps.setInt(3, c.getInt());
-			ps.setString(4, t.getString());
+			ps.setString(3, t.getString());
+			ps.setInt(4, k.getInt());
 			ps.setLong(5, rotationPeriod);
 			int affected = ps.executeUpdate();
 			if (affected < 0 || affected > 1) throw new DbStateException();
@@ -2848,45 +2876,31 @@ abstract class JdbcDatabase implements Database<Connection> {
 	}
 
 	@Override
-	public void updateTransportKeys(Connection txn,
-			Map<ContactId, TransportKeys> keys) throws DbException {
+	public void updateTransportKeys(Connection txn, Collection<KeySet> keys)
+			throws DbException {
 		PreparedStatement ps = null;
 		try {
-			// Delete any existing incoming keys
-			String sql = "DELETE FROM incomingKeys"
-					+ " WHERE contactId = ?"
-					+ " AND transportId = ?";
+			// Delete any existing outgoing keys - this will also remove any
+			// incoming keys with the same key set ID
+			String sql = "DELETE FROM outgoingKeys WHERE keySetId = ?";
 			ps = txn.prepareStatement(sql);
-			for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
-				ps.setInt(1, e.getKey().getInt());
-				ps.setString(2, e.getValue().getTransportId().getString());
+			for (KeySet ks : keys) {
+				ps.setInt(1, ks.getKeySetId().getInt());
 				ps.addBatch();
 			}
 			int[] batchAffected = ps.executeBatch();
 			if (batchAffected.length != keys.size())
 				throw new DbStateException();
-			ps.close();
-			// Delete any existing outgoing keys
-			sql = "DELETE FROM outgoingKeys"
-					+ " WHERE contactId = ?"
-					+ " AND transportId = ?";
-			ps = txn.prepareStatement(sql);
-			for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
-				ps.setInt(1, e.getKey().getInt());
-				ps.setString(2, e.getValue().getTransportId().getString());
-				ps.addBatch();
-			}
-			batchAffected = ps.executeBatch();
-			if (batchAffected.length != keys.size())
-				throw new DbStateException();
+			for (int rows: batchAffected)
+				if (rows < 0) throw new DbStateException();
 			ps.close();
 		} catch (SQLException e) {
 			tryToClose(ps);
 			throw new DbException(e);
 		}
 		// Store the new keys
-		for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
-			addTransportKeys(txn, e.getKey(), e.getValue());
+		for (KeySet ks : keys) {
+			addTransportKeys(txn, ks.getContactId(), ks.getTransportKeys());
 		}
 	}
 }
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/transport/MutableKeySet.java b/bramble-core/src/main/java/org/briarproject/bramble/transport/MutableKeySet.java
new file mode 100644
index 0000000000000000000000000000000000000000..b55c5aef4f8551741c52e6a48d8630aebf86b29e
--- /dev/null
+++ b/bramble-core/src/main/java/org/briarproject/bramble/transport/MutableKeySet.java
@@ -0,0 +1,34 @@
+package org.briarproject.bramble.transport;
+
+import org.briarproject.bramble.api.contact.ContactId;
+import org.briarproject.bramble.api.transport.KeySetId;
+
+import javax.annotation.Nullable;
+
+public class MutableKeySet {
+
+	private final KeySetId keySetId;
+	@Nullable
+	private final ContactId contactId;
+	private final MutableTransportKeys transportKeys;
+
+	public MutableKeySet(KeySetId keySetId, @Nullable ContactId contactId,
+			MutableTransportKeys transportKeys) {
+		this.keySetId = keySetId;
+		this.contactId = contactId;
+		this.transportKeys = transportKeys;
+	}
+
+	public KeySetId getKeySetId() {
+		return keySetId;
+	}
+
+	@Nullable
+	public ContactId getContactId() {
+		return contactId;
+	}
+
+	public MutableTransportKeys getTransportKeys() {
+		return transportKeys;
+	}
+}
diff --git a/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java
index 60b48427ff18f07703a147f827ae39639c6f629e..1220beee580ea7af59c0b682af94546fd80e0068 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java
@@ -11,19 +11,25 @@ import org.briarproject.bramble.api.nullsafety.NotNullByDefault;
 import org.briarproject.bramble.api.plugin.TransportId;
 import org.briarproject.bramble.api.system.Clock;
 import org.briarproject.bramble.api.system.Scheduler;
+import org.briarproject.bramble.api.transport.KeySet;
+import org.briarproject.bramble.api.transport.KeySetId;
 import org.briarproject.bramble.api.transport.StreamContext;
 import org.briarproject.bramble.api.transport.TransportKeys;
 import org.briarproject.bramble.transport.ReorderingWindow.Change;
 
+import java.util.ArrayList;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.concurrent.Executor;
 import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.locks.ReentrantLock;
 import java.util.logging.Logger;
 
+import javax.annotation.Nullable;
 import javax.annotation.concurrent.ThreadSafe;
 
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
@@ -47,12 +53,13 @@ class TransportKeyManagerImpl implements TransportKeyManager {
 	private final Clock clock;
 	private final TransportId transportId;
 	private final long rotationPeriodLength;
-	private final ReentrantLock lock;
+	private final AtomicBoolean used = new AtomicBoolean(false);
+	private final ReentrantLock lock = new ReentrantLock();
 
 	// The following are locking: lock
-	private final Map<Bytes, TagContext> inContexts;
-	private final Map<ContactId, MutableOutgoingKeys> outContexts;
-	private final Map<ContactId, MutableTransportKeys> keys;
+	private final Collection<MutableKeySet> keys = new ArrayList<>();
+	private final Map<Bytes, TagContext> inContexts = new HashMap<>();
+	private final Map<ContactId, MutableKeySet> outContexts = new HashMap<>();
 
 	TransportKeyManagerImpl(DatabaseComponent db,
 			TransportCrypto transportCrypto, Executor dbExecutor,
@@ -65,20 +72,16 @@ class TransportKeyManagerImpl implements TransportKeyManager {
 		this.clock = clock;
 		this.transportId = transportId;
 		rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE;
-		lock = new ReentrantLock();
-		inContexts = new HashMap<>();
-		outContexts = new HashMap<>();
-		keys = new HashMap<>();
 	}
 
 	@Override
 	public void start(Transaction txn) throws DbException {
+		if (used.getAndSet(true)) throw new IllegalStateException();
 		long now = clock.currentTimeMillis();
 		lock.lock();
 		try {
 			// Load the transport keys from the DB
-			Map<ContactId, TransportKeys> loaded =
-					db.getTransportKeys(txn, transportId);
+			Collection<KeySet> loaded = db.getTransportKeys(txn, transportId);
 			// Rotate the keys to the current rotation period
 			RotationResult rotationResult = rotateKeys(loaded, now);
 			// Initialise mutable state for all contacts
@@ -93,41 +96,51 @@ class TransportKeyManagerImpl implements TransportKeyManager {
 		scheduleKeyRotation(now);
 	}
 
-	private RotationResult rotateKeys(Map<ContactId, TransportKeys> keys,
-			long now) {
+	private RotationResult rotateKeys(Collection<KeySet> keys, long now) {
 		RotationResult rotationResult = new RotationResult();
 		long rotationPeriod = now / rotationPeriodLength;
-		for (Entry<ContactId, TransportKeys> e : keys.entrySet()) {
-			ContactId c = e.getKey();
-			TransportKeys k = e.getValue();
+		for (KeySet ks : keys) {
+			TransportKeys k = ks.getTransportKeys();
 			TransportKeys k1 =
 					transportCrypto.rotateTransportKeys(k, rotationPeriod);
+			KeySet ks1 = new KeySet(ks.getKeySetId(), ks.getContactId(), k1);
 			if (k1.getRotationPeriod() > k.getRotationPeriod())
-				rotationResult.rotated.put(c, k1);
-			rotationResult.current.put(c, k1);
+				rotationResult.rotated.add(ks1);
+			rotationResult.current.add(ks1);
 		}
 		return rotationResult;
 	}
 
 	// Locking: lock
-	private void addKeys(Map<ContactId, TransportKeys> m) {
-		for (Entry<ContactId, TransportKeys> e : m.entrySet())
-			addKeys(e.getKey(), new MutableTransportKeys(e.getValue()));
+	private void addKeys(Collection<KeySet> keys) {
+		for (KeySet ks : keys) {
+			addKeys(ks.getKeySetId(), ks.getContactId(),
+					new MutableTransportKeys(ks.getTransportKeys()));
+		}
 	}
 
 	// Locking: lock
-	private void addKeys(ContactId c, MutableTransportKeys m) {
-		encodeTags(c, m.getPreviousIncomingKeys());
-		encodeTags(c, m.getCurrentIncomingKeys());
-		encodeTags(c, m.getNextIncomingKeys());
-		outContexts.put(c, m.getCurrentOutgoingKeys());
-		keys.put(c, m);
+	private void addKeys(KeySetId keySetId, @Nullable ContactId contactId,
+			MutableTransportKeys m) {
+		MutableKeySet ks = new MutableKeySet(keySetId, contactId, m);
+		keys.add(ks);
+		if (contactId != null) {
+			encodeTags(keySetId, contactId, m.getPreviousIncomingKeys());
+			encodeTags(keySetId, contactId, m.getCurrentIncomingKeys());
+			encodeTags(keySetId, contactId, m.getNextIncomingKeys());
+			// Use the outgoing keys with the highest key set ID
+			MutableKeySet old = outContexts.get(contactId);
+			if (old == null || old.getKeySetId().getInt() < keySetId.getInt())
+				outContexts.put(contactId, ks);
+		}
 	}
 
 	// Locking: lock
-	private void encodeTags(ContactId c, MutableIncomingKeys inKeys) {
+	private void encodeTags(KeySetId keySetId, ContactId contactId,
+			MutableIncomingKeys inKeys) {
 		for (long streamNumber : inKeys.getWindow().getUnseen()) {
-			TagContext tagCtx = new TagContext(c, inKeys, streamNumber);
+			TagContext tagCtx =
+					new TagContext(keySetId, contactId, inKeys, streamNumber);
 			byte[] tag = new byte[TAG_LENGTH];
 			transportCrypto.encodeTag(tag, inKeys.getTagKey(), PROTOCOL_VERSION,
 					streamNumber);
@@ -169,10 +182,10 @@ class TransportKeyManagerImpl implements TransportKeyManager {
 			// Rotate the keys to the current rotation period if necessary
 			rotationPeriod = clock.currentTimeMillis() / rotationPeriodLength;
 			k = transportCrypto.rotateTransportKeys(k, rotationPeriod);
-			// Initialise mutable state for the contact
-			addKeys(c, new MutableTransportKeys(k));
 			// Write the keys back to the DB
-			db.addTransportKeys(txn, c, k);
+			KeySetId keySetId = db.addTransportKeys(txn, c, k);
+			// Initialise mutable state for the contact
+			addKeys(keySetId, c, new MutableTransportKeys(k));
 		} finally {
 			lock.unlock();
 		}
@@ -183,12 +196,18 @@ class TransportKeyManagerImpl implements TransportKeyManager {
 		lock.lock();
 		try {
 			// Remove mutable state for the contact
-			Iterator<Entry<Bytes, TagContext>> it =
+			Iterator<Entry<Bytes, TagContext>> inContextsIter =
 					inContexts.entrySet().iterator();
-			while (it.hasNext())
-				if (it.next().getValue().contactId.equals(c)) it.remove();
+			while (inContextsIter.hasNext()) {
+				ContactId c1 = inContextsIter.next().getValue().contactId;
+				if (c1.equals(c)) inContextsIter.remove();
+			}
 			outContexts.remove(c);
-			keys.remove(c);
+			Iterator<MutableKeySet> keysIter = keys.iterator();
+			while (keysIter.hasNext()) {
+				ContactId c1 = keysIter.next().getContactId();
+				if (c1 != null && c1.equals(c)) keysIter.remove();
+			}
 		} finally {
 			lock.unlock();
 		}
@@ -200,8 +219,10 @@ class TransportKeyManagerImpl implements TransportKeyManager {
 		lock.lock();
 		try {
 			// Look up the outgoing keys for the contact
-			MutableOutgoingKeys outKeys = outContexts.get(c);
-			if (outKeys == null) return null;
+			MutableKeySet ks = outContexts.get(c);
+			if (ks == null) return null;
+			MutableOutgoingKeys outKeys =
+					ks.getTransportKeys().getCurrentOutgoingKeys();
 			if (outKeys.getStreamCounter() > MAX_32_BIT_UNSIGNED) return null;
 			// Create a stream context
 			StreamContext ctx = new StreamContext(c, transportId,
@@ -238,8 +259,9 @@ class TransportKeyManagerImpl implements TransportKeyManager {
 				byte[] addTag = new byte[TAG_LENGTH];
 				transportCrypto.encodeTag(addTag, inKeys.getTagKey(),
 						PROTOCOL_VERSION, streamNumber);
-				inContexts.put(new Bytes(addTag), new TagContext(
-						tagCtx.contactId, inKeys, streamNumber));
+				TagContext tagCtx1 = new TagContext(tagCtx.keySetId,
+						tagCtx.contactId, inKeys, streamNumber);
+				inContexts.put(new Bytes(addTag), tagCtx1);
 			}
 			// Remove tags for any stream numbers removed from the window
 			for (long streamNumber : change.getRemoved()) {
@@ -250,7 +272,7 @@ class TransportKeyManagerImpl implements TransportKeyManager {
 				inContexts.remove(new Bytes(removeTag));
 			}
 			// Write the window back to the DB
-			db.setReorderingWindow(txn, tagCtx.contactId, transportId,
+			db.setReorderingWindow(txn, tagCtx.keySetId, transportId,
 					inKeys.getRotationPeriod(), window.getBase(),
 					window.getBitmap());
 			return ctx;
@@ -264,9 +286,11 @@ class TransportKeyManagerImpl implements TransportKeyManager {
 		lock.lock();
 		try {
 			// Rotate the keys to the current rotation period
-			Map<ContactId, TransportKeys> snapshot = new HashMap<>();
-			for (Entry<ContactId, MutableTransportKeys> e : keys.entrySet())
-				snapshot.put(e.getKey(), e.getValue().snapshot());
+			Collection<KeySet> snapshot = new ArrayList<>(keys.size());
+			for (MutableKeySet ks : keys) {
+				snapshot.add(new KeySet(ks.getKeySetId(), ks.getContactId(),
+						ks.getTransportKeys().snapshot()));
+			}
 			RotationResult rotationResult = rotateKeys(snapshot, now);
 			// Rebuild the mutable state for all contacts
 			inContexts.clear();
@@ -285,12 +309,14 @@ class TransportKeyManagerImpl implements TransportKeyManager {
 
 	private static class TagContext {
 
+		private final KeySetId keySetId;
 		private final ContactId contactId;
 		private final MutableIncomingKeys inKeys;
 		private final long streamNumber;
 
-		private TagContext(ContactId contactId, MutableIncomingKeys inKeys,
-				long streamNumber) {
+		private TagContext(KeySetId keySetId, ContactId contactId,
+				MutableIncomingKeys inKeys, long streamNumber) {
+			this.keySetId = keySetId;
 			this.contactId = contactId;
 			this.inKeys = inKeys;
 			this.streamNumber = streamNumber;
@@ -299,11 +325,7 @@ class TransportKeyManagerImpl implements TransportKeyManager {
 
 	private static class RotationResult {
 
-		private final Map<ContactId, TransportKeys> current, rotated;
-
-		private RotationResult() {
-			current = new HashMap<>();
-			rotated = new HashMap<>();
-		}
+		private final Collection<KeySet> current = new ArrayList<>();
+		private final Collection<KeySet> rotated = new ArrayList<>();
 	}
 }
diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java
index 0ed4574618fc17e8eac6f0dcb59cb395f206e78d..eb31247ce074642cd6c9fcc29de75280def9151c 100644
--- a/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java
+++ b/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java
@@ -44,6 +44,8 @@ import org.briarproject.bramble.api.sync.event.MessageToRequestEvent;
 import org.briarproject.bramble.api.sync.event.MessagesAckedEvent;
 import org.briarproject.bramble.api.sync.event.MessagesSentEvent;
 import org.briarproject.bramble.api.transport.IncomingKeys;
+import org.briarproject.bramble.api.transport.KeySet;
+import org.briarproject.bramble.api.transport.KeySetId;
 import org.briarproject.bramble.api.transport.OutgoingKeys;
 import org.briarproject.bramble.api.transport.TransportKeys;
 import org.briarproject.bramble.test.BrambleMockTestCase;
@@ -55,12 +57,10 @@ import org.junit.Test;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
-import java.util.Map;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.singletonList;
-import static java.util.Collections.singletonMap;
 import static org.briarproject.bramble.api.sync.Group.Visibility.INVISIBLE;
 import static org.briarproject.bramble.api.sync.Group.Visibility.SHARED;
 import static org.briarproject.bramble.api.sync.Group.Visibility.VISIBLE;
@@ -100,6 +100,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
 	private final int maxLatency;
 	private final ContactId contactId;
 	private final Contact contact;
+	private final KeySetId keySetId;
 
 	public DatabaseComponentImplTest() {
 		clientId = new ClientId(getRandomString(123));
@@ -121,6 +122,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
 		contactId = new ContactId(234);
 		contact = new Contact(contactId, author, localAuthor.getId(),
 				true, true);
+		keySetId = new KeySetId(345);
 	}
 
 	private DatabaseComponent createDatabaseComponent(Database<Object> database,
@@ -282,11 +284,11 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
 			throws Exception {
 		context.checking(new Expectations() {{
 			// Check whether the contact is in the DB (which it's not)
-			exactly(18).of(database).startTransaction();
+			exactly(17).of(database).startTransaction();
 			will(returnValue(txn));
-			exactly(18).of(database).containsContact(txn, contactId);
+			exactly(17).of(database).containsContact(txn, contactId);
 			will(returnValue(false));
-			exactly(18).of(database).abortTransaction(txn);
+			exactly(17).of(database).abortTransaction(txn);
 		}});
 		DatabaseComponent db = createDatabaseComponent(database, eventBus,
 				shutdown);
@@ -454,17 +456,6 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
 			db.endTransaction(transaction);
 		}
 
-		transaction = db.startTransaction(false);
-		try {
-			db.setReorderingWindow(transaction, contactId, transportId, 0, 0,
-					new byte[REORDERING_WINDOW_SIZE / 8]);
-			fail();
-		} catch (NoSuchContactException expected) {
-			// Expected
-		} finally {
-			db.endTransaction(transaction);
-		}
-
 		transaction = db.startTransaction(false);
 		try {
 			db.setGroupVisibility(transaction, contactId, groupId, SHARED);
@@ -779,7 +770,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
 			// Check whether the transport is in the DB (which it's not)
 			exactly(4).of(database).startTransaction();
 			will(returnValue(txn));
-			exactly(2).of(database).containsContact(txn, contactId);
+			oneOf(database).containsContact(txn, contactId);
 			will(returnValue(true));
 			exactly(4).of(database).containsTransport(txn, transportId);
 			will(returnValue(false));
@@ -830,7 +821,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
 
 		transaction = db.startTransaction(false);
 		try {
-			db.setReorderingWindow(transaction, contactId, transportId, 0, 0,
+			db.setReorderingWindow(transaction, keySetId, transportId, 0, 0,
 					new byte[REORDERING_WINDOW_SIZE / 8]);
 			fail();
 		} catch (NoSuchTransportException expected) {
@@ -1303,15 +1294,13 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
 	@Test
 	public void testTransportKeys() throws Exception {
 		TransportKeys transportKeys = createTransportKeys();
-		Map<ContactId, TransportKeys> keys =
-				singletonMap(contactId, transportKeys);
+		Collection<KeySet> keys =
+				singletonList(new KeySet(keySetId, contactId, transportKeys));
 		context.checking(new Expectations() {{
 			// startTransaction()
 			oneOf(database).startTransaction();
 			will(returnValue(txn));
 			// updateTransportKeys()
-			oneOf(database).containsContact(txn, contactId);
-			will(returnValue(true));
 			oneOf(database).containsTransport(txn, transportId);
 			will(returnValue(true));
 			oneOf(database).updateTransportKeys(txn, keys);
diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java
index d81cfd85149cdc1ec6e2a53d0b1a4de2b4afeb7a..4237af191b8123ba70c3deacb79e21b0124ef725 100644
--- a/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java
+++ b/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java
@@ -19,6 +19,8 @@ import org.briarproject.bramble.api.sync.MessageStatus;
 import org.briarproject.bramble.api.sync.ValidationManager.State;
 import org.briarproject.bramble.api.system.Clock;
 import org.briarproject.bramble.api.transport.IncomingKeys;
+import org.briarproject.bramble.api.transport.KeySet;
+import org.briarproject.bramble.api.transport.KeySetId;
 import org.briarproject.bramble.api.transport.OutgoingKeys;
 import org.briarproject.bramble.api.transport.TransportKeys;
 import org.briarproject.bramble.system.SystemClock;
@@ -34,7 +36,6 @@ import java.sql.Connection;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
-import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
@@ -42,6 +43,10 @@ import java.util.Random;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicBoolean;
 
+import static java.util.Collections.emptyList;
+import static java.util.Collections.emptyMap;
+import static java.util.Collections.singletonList;
+import static java.util.Collections.singletonMap;
 import static java.util.concurrent.TimeUnit.SECONDS;
 import static org.briarproject.bramble.api.db.Metadata.REMOVE;
 import static org.briarproject.bramble.api.sync.Group.Visibility.INVISIBLE;
@@ -86,6 +91,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 	private final Message message;
 	private final TransportId transportId;
 	private final ContactId contactId;
+	private final KeySetId keySetId;
 
 	JdbcDatabaseTest() throws Exception {
 		groupId = new GroupId(getRandomId());
@@ -101,6 +107,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 		message = new Message(messageId, groupId, timestamp, raw);
 		transportId = new TransportId("id");
 		contactId = new ContactId(1);
+		keySetId = new KeySetId(1);
 	}
 
 	protected abstract JdbcDatabase createDatabase(DatabaseConfig config,
@@ -190,9 +197,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 		// The contact has not seen the message, so it should be sendable
 		Collection<MessageId> ids =
 				db.getMessagesToSend(txn, contactId, ONE_MEGABYTE);
-		assertEquals(Collections.singletonList(messageId), ids);
+		assertEquals(singletonList(messageId), ids);
 		ids = db.getMessagesToOffer(txn, contactId, 100);
-		assertEquals(Collections.singletonList(messageId), ids);
+		assertEquals(singletonList(messageId), ids);
 
 		// Changing the status to seen = true should make the message unsendable
 		db.raiseSeenFlag(txn, contactId, messageId);
@@ -228,9 +235,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 		// Marking the message delivered should make it sendable
 		db.setMessageState(txn, messageId, DELIVERED);
 		ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE);
-		assertEquals(Collections.singletonList(messageId), ids);
+		assertEquals(singletonList(messageId), ids);
 		ids = db.getMessagesToOffer(txn, contactId, 100);
-		assertEquals(Collections.singletonList(messageId), ids);
+		assertEquals(singletonList(messageId), ids);
 
 		// Marking the message invalid should make it unsendable
 		db.setMessageState(txn, messageId, INVALID);
@@ -279,9 +286,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 		// Sharing the group should make the message sendable
 		db.setGroupVisibility(txn, contactId, groupId, true);
 		ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE);
-		assertEquals(Collections.singletonList(messageId), ids);
+		assertEquals(singletonList(messageId), ids);
 		ids = db.getMessagesToOffer(txn, contactId, 100);
-		assertEquals(Collections.singletonList(messageId), ids);
+		assertEquals(singletonList(messageId), ids);
 
 		// Unsharing the group should make the message unsendable
 		db.setGroupVisibility(txn, contactId, groupId, false);
@@ -324,9 +331,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 		// Sharing the message should make it sendable
 		db.setMessageShared(txn, messageId);
 		ids = db.getMessagesToSend(txn, contactId, ONE_MEGABYTE);
-		assertEquals(Collections.singletonList(messageId), ids);
+		assertEquals(singletonList(messageId), ids);
 		ids = db.getMessagesToOffer(txn, contactId, 100);
-		assertEquals(Collections.singletonList(messageId), ids);
+		assertEquals(singletonList(messageId), ids);
 
 		db.commitTransaction(txn);
 		db.close();
@@ -352,7 +359,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 
 		// The message is just the right size to send
 		ids = db.getMessagesToSend(txn, contactId, size);
-		assertEquals(Collections.singletonList(messageId), ids);
+		assertEquals(singletonList(messageId), ids);
 
 		db.commitTransaction(txn);
 		db.close();
@@ -384,7 +391,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 		db.lowerAckFlag(txn, contactId, Arrays.asList(messageId, messageId1));
 
 		// Both message IDs should have been removed
-		assertEquals(Collections.emptyList(), db.getMessagesToAck(txn,
+		assertEquals(emptyList(), db.getMessagesToAck(txn,
 				contactId, 1234));
 
 		// Raise the ack flag again
@@ -415,7 +422,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 		// Retrieve the message from the database and mark it as sent
 		Collection<MessageId> ids = db.getMessagesToSend(txn, contactId,
 				ONE_MEGABYTE);
-		assertEquals(Collections.singletonList(messageId), ids);
+		assertEquals(singletonList(messageId), ids);
 		db.updateExpiryTime(txn, contactId, messageId, Integer.MAX_VALUE);
 
 		// The message should no longer be sendable
@@ -626,31 +633,31 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 
 		// The group should not be visible to the contact
 		assertEquals(INVISIBLE, db.getGroupVisibility(txn, contactId, groupId));
-		assertEquals(Collections.emptyMap(),
+		assertEquals(emptyMap(),
 				db.getGroupVisibility(txn, groupId));
 
 		// Make the group visible to the contact
 		db.addGroupVisibility(txn, contactId, groupId, false);
 		assertEquals(VISIBLE, db.getGroupVisibility(txn, contactId, groupId));
-		assertEquals(Collections.singletonMap(contactId, false),
+		assertEquals(singletonMap(contactId, false),
 				db.getGroupVisibility(txn, groupId));
 
 		// Share the group with the contact
 		db.setGroupVisibility(txn, contactId, groupId, true);
 		assertEquals(SHARED, db.getGroupVisibility(txn, contactId, groupId));
-		assertEquals(Collections.singletonMap(contactId, true),
+		assertEquals(singletonMap(contactId, true),
 				db.getGroupVisibility(txn, groupId));
 
 		// Unshare the group with the contact
 		db.setGroupVisibility(txn, contactId, groupId, false);
 		assertEquals(VISIBLE, db.getGroupVisibility(txn, contactId, groupId));
-		assertEquals(Collections.singletonMap(contactId, false),
+		assertEquals(singletonMap(contactId, false),
 				db.getGroupVisibility(txn, groupId));
 
 		// Make the group invisible again
 		db.removeGroupVisibility(txn, contactId, groupId);
 		assertEquals(INVISIBLE, db.getGroupVisibility(txn, contactId, groupId));
-		assertEquals(Collections.emptyMap(),
+		assertEquals(emptyMap(),
 				db.getGroupVisibility(txn, groupId));
 
 		db.commitTransaction(txn);
@@ -665,24 +672,22 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 		Connection txn = db.startTransaction();
 
 		// Initially there should be no transport keys in the database
-		assertEquals(Collections.emptyMap(),
-				db.getTransportKeys(txn, transportId));
+		assertEquals(emptyList(), db.getTransportKeys(txn, transportId));
 
 		// Add the contact, the transport and the transport keys
 		db.addLocalAuthor(txn, localAuthor);
 		assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(),
 				true, true));
 		db.addTransport(txn, transportId, 123);
-		db.addTransportKeys(txn, contactId, keys);
+		assertEquals(keySetId, db.addTransportKeys(txn, contactId, keys));
 
 		// Retrieve the transport keys
-		Map<ContactId, TransportKeys> newKeys =
-				db.getTransportKeys(txn, transportId);
+		Collection<KeySet> newKeys = db.getTransportKeys(txn, transportId);
 		assertEquals(1, newKeys.size());
-		Entry<ContactId, TransportKeys> e =
-				newKeys.entrySet().iterator().next();
-		assertEquals(contactId, e.getKey());
-		TransportKeys k = e.getValue();
+		KeySet ks = newKeys.iterator().next();
+		assertEquals(keySetId, ks.getKeySetId());
+		assertEquals(contactId, ks.getContactId());
+		TransportKeys k = ks.getTransportKeys();
 		assertEquals(transportId, k.getTransportId());
 		assertKeysEquals(keys.getPreviousIncomingKeys(),
 				k.getPreviousIncomingKeys());
@@ -695,8 +700,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 
 		// Removing the contact should remove the transport keys
 		db.removeContact(txn, contactId);
-		assertEquals(Collections.emptyMap(),
-				db.getTransportKeys(txn, transportId));
+		assertEquals(emptyList(), db.getTransportKeys(txn, transportId));
 
 		db.commitTransaction(txn);
 		db.close();
@@ -735,18 +739,18 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 		assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(),
 				true, true));
 		db.addTransport(txn, transportId, 123);
-		db.updateTransportKeys(txn, Collections.singletonMap(contactId, keys));
+		db.updateTransportKeys(txn,
+				singletonList(new KeySet(keySetId, contactId, keys)));
 
 		// Increment the stream counter twice and retrieve the transport keys
 		db.incrementStreamCounter(txn, contactId, transportId, rotationPeriod);
 		db.incrementStreamCounter(txn, contactId, transportId, rotationPeriod);
-		Map<ContactId, TransportKeys> newKeys =
-				db.getTransportKeys(txn, transportId);
+		Collection<KeySet> newKeys = db.getTransportKeys(txn, transportId);
 		assertEquals(1, newKeys.size());
-		Entry<ContactId, TransportKeys> e =
-				newKeys.entrySet().iterator().next();
-		assertEquals(contactId, e.getKey());
-		TransportKeys k = e.getValue();
+		KeySet ks = newKeys.iterator().next();
+		assertEquals(keySetId, ks.getKeySetId());
+		assertEquals(contactId, ks.getContactId());
+		TransportKeys k = ks.getTransportKeys();
 		assertEquals(transportId, k.getTransportId());
 		OutgoingKeys outCurr = k.getCurrentOutgoingKeys();
 		assertEquals(rotationPeriod, outCurr.getRotationPeriod());
@@ -771,19 +775,19 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 		assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(),
 				true, true));
 		db.addTransport(txn, transportId, 123);
-		db.updateTransportKeys(txn, Collections.singletonMap(contactId, keys));
+		db.updateTransportKeys(txn,
+				singletonList(new KeySet(keySetId, contactId, keys)));
 
 		// Update the reordering window and retrieve the transport keys
 		new Random().nextBytes(bitmap);
-		db.setReorderingWindow(txn, contactId, transportId, rotationPeriod,
+		db.setReorderingWindow(txn, keySetId, transportId, rotationPeriod,
 				base + 1, bitmap);
-		Map<ContactId, TransportKeys> newKeys =
-				db.getTransportKeys(txn, transportId);
+		Collection<KeySet> newKeys = db.getTransportKeys(txn, transportId);
 		assertEquals(1, newKeys.size());
-		Entry<ContactId, TransportKeys> e =
-				newKeys.entrySet().iterator().next();
-		assertEquals(contactId, e.getKey());
-		TransportKeys k = e.getValue();
+		KeySet ks = newKeys.iterator().next();
+		assertEquals(keySetId, ks.getKeySetId());
+		assertEquals(contactId, ks.getContactId());
+		TransportKeys k = ks.getTransportKeys();
 		assertEquals(transportId, k.getTransportId());
 		IncomingKeys inCurr = k.getCurrentIncomingKeys();
 		assertEquals(rotationPeriod, inCurr.getRotationPeriod());
@@ -830,18 +834,18 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 		db.addLocalAuthor(txn, localAuthor);
 		Collection<ContactId> contacts =
 				db.getContacts(txn, localAuthor.getId());
-		assertEquals(Collections.emptyList(), contacts);
+		assertEquals(emptyList(), contacts);
 
 		// Add a contact associated with the local author
 		assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(),
 				true, true));
 		contacts = db.getContacts(txn, localAuthor.getId());
-		assertEquals(Collections.singletonList(contactId), contacts);
+		assertEquals(singletonList(contactId), contacts);
 
 		// Remove the local author - the contact should be removed
 		db.removeLocalAuthor(txn, localAuthor.getId());
 		contacts = db.getContacts(txn, localAuthor.getId());
-		assertEquals(Collections.emptyList(), contacts);
+		assertEquals(emptyList(), contacts);
 		assertFalse(db.containsContact(txn, contactId));
 
 		db.commitTransaction(txn);
@@ -1560,9 +1564,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 		// The message should be sendable
 		Collection<MessageId> ids = db.getMessagesToSend(txn, contactId,
 				ONE_MEGABYTE);
-		assertEquals(Collections.singletonList(messageId), ids);
+		assertEquals(singletonList(messageId), ids);
 		ids = db.getMessagesToOffer(txn, contactId, 100);
-		assertEquals(Collections.singletonList(messageId), ids);
+		assertEquals(singletonList(messageId), ids);
 
 		// The raw message should not be null
 		assertNotNull(db.getRawMessage(txn, messageId));
diff --git a/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java
index 4dc1f980240d9d91daa57860c5da1ea40d47b93d..213400f1e6096181c5c18a8fdc3bd437d64a4c6e 100644
--- a/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java
+++ b/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java
@@ -8,6 +8,8 @@ import org.briarproject.bramble.api.db.Transaction;
 import org.briarproject.bramble.api.plugin.TransportId;
 import org.briarproject.bramble.api.system.Clock;
 import org.briarproject.bramble.api.transport.IncomingKeys;
+import org.briarproject.bramble.api.transport.KeySet;
+import org.briarproject.bramble.api.transport.KeySetId;
 import org.briarproject.bramble.api.transport.OutgoingKeys;
 import org.briarproject.bramble.api.transport.StreamContext;
 import org.briarproject.bramble.api.transport.TransportKeys;
@@ -22,14 +24,13 @@ import org.junit.Test;
 
 import java.util.ArrayList;
 import java.util.Collection;
-import java.util.Collections;
-import java.util.LinkedHashMap;
 import java.util.List;
-import java.util.Map;
 import java.util.Random;
 import java.util.concurrent.Executor;
 import java.util.concurrent.ScheduledExecutorService;
 
+import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
 import static org.briarproject.bramble.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE;
 import static org.briarproject.bramble.api.transport.TransportConstants.PROTOCOL_VERSION;
@@ -55,6 +56,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
 	private final long rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE;
 	private final ContactId contactId = new ContactId(123);
 	private final ContactId contactId1 = new ContactId(234);
+	private final KeySetId keySetId = new KeySetId(345);
 	private final SecretKey tagKey = TestUtils.getSecretKey();
 	private final SecretKey headerKey = TestUtils.getSecretKey();
 	private final SecretKey masterKey = TestUtils.getSecretKey();
@@ -62,11 +64,12 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
 
 	@Test
 	public void testKeysAreRotatedAtStartup() throws Exception {
-		Map<ContactId, TransportKeys> loaded = new LinkedHashMap<>();
 		TransportKeys shouldRotate = createTransportKeys(900, 0);
 		TransportKeys shouldNotRotate = createTransportKeys(1000, 0);
-		loaded.put(contactId, shouldRotate);
-		loaded.put(contactId1, shouldNotRotate);
+		Collection<KeySet> loaded = asList(
+				new KeySet(keySetId, contactId, shouldRotate),
+				new KeySet(keySetId, contactId1, shouldNotRotate)
+		);
 		TransportKeys rotated = createTransportKeys(1000, 0);
 		Transaction txn = new Transaction(null, false);
 
@@ -91,7 +94,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
 			}
 			// Save the keys that were rotated
 			oneOf(db).updateTransportKeys(txn,
-					Collections.singletonMap(contactId, rotated));
+					singletonList(new KeySet(keySetId, contactId, rotated)));
 			// Schedule key rotation at the start of the next rotation period
 			oneOf(scheduler).schedule(with(any(Runnable.class)),
 					with(rotationPeriodLength - 1), with(MILLISECONDS));
@@ -129,6 +132,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
 			}
 			// Save the keys
 			oneOf(db).addTransportKeys(txn, contactId, rotated);
+			will(returnValue(keySetId));
 		}});
 
 		TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
@@ -179,6 +183,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
 			will(returnValue(transportKeys));
 			// Save the keys
 			oneOf(db).addTransportKeys(txn, contactId, transportKeys);
+			will(returnValue(keySetId));
 		}});
 
 		TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
@@ -218,6 +223,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
 			will(returnValue(transportKeys));
 			// Save the keys
 			oneOf(db).addTransportKeys(txn, contactId, transportKeys);
+			will(returnValue(keySetId));
 			// Increment the stream counter
 			oneOf(db).incrementStreamCounter(txn, contactId, transportId, 1000);
 		}});
@@ -268,6 +274,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
 			will(returnValue(transportKeys));
 			// Save the keys
 			oneOf(db).addTransportKeys(txn, contactId, transportKeys);
+			will(returnValue(keySetId));
 		}});
 
 		TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
@@ -308,13 +315,14 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
 			will(returnValue(transportKeys));
 			// Save the keys
 			oneOf(db).addTransportKeys(txn, contactId, transportKeys);
+			will(returnValue(keySetId));
 			// Encode a new tag after sliding the window
 			oneOf(transportCrypto).encodeTag(with(any(byte[].class)),
 					with(tagKey), with(PROTOCOL_VERSION),
 					with((long) REORDERING_WINDOW_SIZE));
 			will(new EncodeTagAction(tags));
 			// Save the reordering window (previous rotation period, base 1)
-			oneOf(db).setReorderingWindow(txn, contactId, transportId, 999,
+			oneOf(db).setReorderingWindow(txn, keySetId, transportId, 999,
 					1, new byte[REORDERING_WINDOW_SIZE / 8]);
 		}});
 
@@ -345,8 +353,8 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
 	@Test
 	public void testKeysAreRotatedToCurrentPeriod() throws Exception {
 		TransportKeys transportKeys = createTransportKeys(1000, 0);
-		Map<ContactId, TransportKeys> loaded =
-				Collections.singletonMap(contactId, transportKeys);
+		Collection<KeySet> loaded =
+				singletonList(new KeySet(keySetId, contactId, transportKeys));
 		TransportKeys rotated = createTransportKeys(1001, 0);
 		Transaction txn = new Transaction(null, false);
 		Transaction txn1 = new Transaction(null, false);
@@ -393,7 +401,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
 			}
 			// Save the keys that were rotated
 			oneOf(db).updateTransportKeys(txn1,
-					Collections.singletonMap(contactId, rotated));
+					singletonList(new KeySet(keySetId, contactId, rotated)));
 			// Schedule key rotation at the start of the next rotation period
 			oneOf(scheduler).schedule(with(any(Runnable.class)),
 					with(rotationPeriodLength), with(MILLISECONDS));