diff --git a/briar-core/src/org/briarproject/transport/TransportKeyManager.java b/briar-core/src/org/briarproject/transport/TransportKeyManager.java
index b380765787d532ff011c9d021e64581d80d5e91e..ffc8c3bbfc00cf8223481bb0a15dbade4ebbfa63 100644
--- a/briar-core/src/org/briarproject/transport/TransportKeyManager.java
+++ b/briar-core/src/org/briarproject/transport/TransportKeyManager.java
@@ -27,7 +27,7 @@ import static org.briarproject.api.transport.TransportConstants.MAX_CLOCK_DIFFER
 import static org.briarproject.api.transport.TransportConstants.TAG_LENGTH;
 import static org.briarproject.util.ByteUtils.MAX_32_BIT_UNSIGNED;
 
-class TransportKeyManager extends TimerTask {
+class TransportKeyManager {
 
 	private static final Logger LOG =
 			Logger.getLogger(TransportKeyManager.class.getName());
@@ -97,21 +97,14 @@ class TransportKeyManager extends TimerTask {
 			for (Entry<ContactId, TransportKeys> e : current.entrySet())
 				addKeys(e.getKey(), new MutableTransportKeys(e.getValue()));
 			// Write any rotated keys back to the DB
-			Transaction txn = db.startTransaction();
-			try {
-				db.updateTransportKeys(txn, rotated);
-				txn.setComplete();
-			} finally {
-				db.endTransaction(txn);
-			}
+			updateTransportKeys(rotated);
 		} catch (DbException e) {
 			if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
 		} finally {
 			lock.unlock();
 		}
 		// Schedule the next key rotation
-		long delay = rotationPeriodLength - now % rotationPeriodLength;
-		timer.schedule(this, delay);
+		scheduleKeyRotation(now);
 	}
 
 	// Locking: lock
@@ -133,6 +126,29 @@ class TransportKeyManager extends TimerTask {
 		}
 	}
 
+	private void updateTransportKeys(Map<ContactId, TransportKeys> rotated)
+		throws DbException {
+		if (!rotated.isEmpty()) {
+			Transaction txn = db.startTransaction();
+			try {
+				db.updateTransportKeys(txn, rotated);
+				txn.setComplete();
+			} finally {
+				db.endTransaction(txn);
+			}
+		}
+	}
+
+	private void scheduleKeyRotation(long now) {
+		TimerTask task = new TimerTask() {
+			public void run() {
+				rotateKeys();
+			}
+		};
+		long delay = rotationPeriodLength - now % rotationPeriodLength;
+		timer.schedule(task, delay);
+	}
+
 	void addContact(ContactId c, SecretKey master, long timestamp,
 			boolean alice) {
 		lock.lock();
@@ -230,6 +246,7 @@ class TransportKeyManager extends TimerTask {
 			}
 			// Remove tags for any stream numbers removed from the window
 			for (long streamNumber : change.getRemoved()) {
+				if (streamNumber == tagCtx.streamNumber) continue;
 				byte[] removeTag = new byte[TAG_LENGTH];
 				crypto.encodeTag(removeTag, inKeys.getTagKey(), streamNumber);
 				inContexts.remove(new Bytes(removeTag));
@@ -253,8 +270,7 @@ class TransportKeyManager extends TimerTask {
 		}
 	}
 
-	@Override
-	public void run() {
+	private void rotateKeys() {
 		long now = clock.currentTimeMillis();
 		lock.lock();
 		try {
@@ -280,21 +296,14 @@ class TransportKeyManager extends TimerTask {
 			for (Entry<ContactId, TransportKeys> e : current.entrySet())
 				addKeys(e.getKey(), new MutableTransportKeys(e.getValue()));
 			// Write any rotated keys back to the DB
-			Transaction txn = db.startTransaction();
-			try {
-				db.updateTransportKeys(txn, rotated);
-				txn.setComplete();
-			} finally {
-				db.endTransaction(txn);
-			}
+			updateTransportKeys(rotated);
 		} catch (DbException e) {
 			if (LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e);
 		} finally {
 			lock.unlock();
 		}
 		// Schedule the next key rotation
-		long delay = rotationPeriodLength - now % rotationPeriodLength;
-		timer.schedule(this, delay);
+		scheduleKeyRotation(now);
 	}
 
 	private static class TagContext {
diff --git a/briar-tests/src/org/briarproject/transport/ReorderingWindowTest.java b/briar-tests/src/org/briarproject/transport/ReorderingWindowTest.java
index 85ef23a6e86b858ce85cd1059a6d98187d5077cc..1f18cbe5e363d03800b6b65e984791f98f49f51f 100644
--- a/briar-tests/src/org/briarproject/transport/ReorderingWindowTest.java
+++ b/briar-tests/src/org/briarproject/transport/ReorderingWindowTest.java
@@ -1,22 +1,8 @@
 package org.briarproject.transport;
 
 import org.briarproject.BriarTestCase;
-import org.briarproject.api.transport.TransportConstants;
 import org.briarproject.transport.ReorderingWindow.Change;
-import org.junit.Assert;
 import org.junit.Test;
-import org.briarproject.BriarTestCase;
-import org.junit.Test;
-
-import java.util.Collection;
-
-import static org.briarproject.api.transport.TransportConstants.REORDERING_WINDOW_SIZE;
-import static org.briarproject.util.ByteUtils.MAX_32_BIT_UNSIGNED;
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
 
 import java.util.Arrays;
 import java.util.Collections;
@@ -24,6 +10,7 @@ import java.util.Random;
 
 import static org.briarproject.api.transport.TransportConstants.REORDERING_WINDOW_SIZE;
 import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
 
 public class ReorderingWindowTest extends BriarTestCase {
 
@@ -46,7 +33,8 @@ public class ReorderingWindowTest extends BriarTestCase {
 		Change change = window.setSeen(0L);
 		// The window should slide by one element
 		assertEquals(1L, window.getBase());
-		assertEquals(Collections.singletonList((long) REORDERING_WINDOW_SIZE), change.getAdded());
+		assertEquals(Collections.singletonList((long) REORDERING_WINDOW_SIZE),
+				change.getAdded());
 		assertEquals(Collections.singletonList(0L), change.getRemoved());
 		// All elements in the window should be unseen
 		assertArrayEquals(bitmap, window.getBitmap());
@@ -76,7 +64,8 @@ public class ReorderingWindowTest extends BriarTestCase {
 		Change change = window.setSeen(aboveMidpoint);
 		// The window should slide by one element
 		assertEquals(1L, window.getBase());
-		assertEquals(Collections.singletonList((long) REORDERING_WINDOW_SIZE), change.getAdded());
+		assertEquals(Collections.singletonList((long) REORDERING_WINDOW_SIZE),
+				change.getAdded());
 		assertEquals(Arrays.asList(0L, aboveMidpoint), change.getRemoved());
 		// The highest element below the midpoint should be seen
 		bitmap[bitmap.length / 2 - 1] = (byte) 0x01; // 0000 0001
diff --git a/briar-tests/src/org/briarproject/transport/TransportKeyManagerTest.java b/briar-tests/src/org/briarproject/transport/TransportKeyManagerTest.java
index 8183864059dec3645d9c9893dedac0331bfbb8ff..2422bb45e3ef87dc3b0acf92eee502c8368fec6d 100644
--- a/briar-tests/src/org/briarproject/transport/TransportKeyManagerTest.java
+++ b/briar-tests/src/org/briarproject/transport/TransportKeyManagerTest.java
@@ -1,14 +1,508 @@
 package org.briarproject.transport;
 
 import org.briarproject.BriarTestCase;
+import org.briarproject.TestUtils;
+import org.briarproject.api.TransportId;
+import org.briarproject.api.contact.ContactId;
+import org.briarproject.api.crypto.CryptoComponent;
+import org.briarproject.api.crypto.SecretKey;
+import org.briarproject.api.db.DatabaseComponent;
+import org.briarproject.api.db.Transaction;
+import org.briarproject.api.system.Clock;
+import org.briarproject.api.system.Timer;
+import org.briarproject.api.transport.IncomingKeys;
+import org.briarproject.api.transport.OutgoingKeys;
+import org.briarproject.api.transport.StreamContext;
+import org.briarproject.api.transport.TransportKeys;
+import org.hamcrest.Description;
+import org.jmock.Expectations;
+import org.jmock.Mockery;
+import org.jmock.api.Action;
+import org.jmock.api.Invocation;
 import org.junit.Test;
 
-import static org.junit.Assert.fail;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.TimerTask;
+
+import static org.briarproject.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE;
+import static org.briarproject.api.transport.TransportConstants.REORDERING_WINDOW_SIZE;
+import static org.briarproject.api.transport.TransportConstants.TAG_LENGTH;
+import static org.briarproject.util.ByteUtils.MAX_32_BIT_UNSIGNED;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
 
 public class TransportKeyManagerTest extends BriarTestCase {
 
+	private final TransportId transportId = new TransportId("id");
+	private final long maxLatency = 30 * 1000; // 30 seconds
+	private final long rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE;
+	private final ContactId contactId = new ContactId(123);
+	private final ContactId contactId1 = new ContactId(234);
+	private final SecretKey tagKey = TestUtils.createSecretKey();
+	private final SecretKey headerKey = TestUtils.createSecretKey();
+	private final SecretKey masterKey = TestUtils.createSecretKey();
+	private final Random random = new Random();
+
+	@Test
+	public void testKeysAreRotatedAtStartup() throws Exception {
+		Mockery context = new Mockery();
+		final DatabaseComponent db = context.mock(DatabaseComponent.class);
+		final CryptoComponent crypto = context.mock(CryptoComponent.class);
+		final Timer timer = context.mock(Timer.class);
+		final Clock clock = context.mock(Clock.class);
+		final Transaction txn = new Transaction(null);
+		final Map<ContactId, TransportKeys> loaded =
+				new LinkedHashMap<ContactId, TransportKeys>();
+		final TransportKeys shouldRotate = createTransportKeys(900, 0);
+		final TransportKeys shouldNotRotate = createTransportKeys(1000, 0);
+		loaded.put(contactId, shouldRotate);
+		loaded.put(contactId1, shouldNotRotate);
+		final TransportKeys rotated = createTransportKeys(1000, 0);
+		final Transaction txn1 = new Transaction(null);
+		context.checking(new Expectations() {{
+			// Get the current time (1 ms after start of rotation period 1000)
+			oneOf(clock).currentTimeMillis();
+			will(returnValue(rotationPeriodLength * 1000 + 1));
+			// Load the transport keys
+			oneOf(db).startTransaction();
+			will(returnValue(txn));
+			oneOf(db).getTransportKeys(txn, transportId);
+			will(returnValue(loaded));
+			oneOf(db).endTransaction(txn);
+			// Rotate the transport keys
+			oneOf(crypto).rotateTransportKeys(shouldRotate, 1000);
+			will(returnValue(rotated));
+			oneOf(crypto).rotateTransportKeys(shouldNotRotate, 1000);
+			will(returnValue(shouldNotRotate));
+			// Encode the tags (3 sets per contact)
+			for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
+				exactly(6).of(crypto).encodeTag(with(any(byte[].class)),
+						with(tagKey), with(i));
+				will(new EncodeTagAction());
+			}
+			// Save the keys that were rotated
+			oneOf(db).startTransaction();
+			will(returnValue(txn1));
+			oneOf(db).updateTransportKeys(txn1,
+					Collections.singletonMap(contactId, rotated));
+			oneOf(db).endTransaction(txn1);
+			// Schedule key rotation at the start of the next rotation period
+			oneOf(timer).schedule(with(any(TimerTask.class)),
+					with(rotationPeriodLength - 1));
+		}});
+
+		TransportKeyManager transportKeyManager = new TransportKeyManager(db,
+				crypto, timer, clock, transportId, maxLatency);
+		transportKeyManager.start();
+
+		context.assertIsSatisfied();
+	}
+
+	@Test
+	public void testKeysAreRotatedWhenAddingContact() throws Exception {
+		Mockery context = new Mockery();
+		final DatabaseComponent db = context.mock(DatabaseComponent.class);
+		final CryptoComponent crypto = context.mock(CryptoComponent.class);
+		final Timer timer = context.mock(Timer.class);
+		final Clock clock = context.mock(Clock.class);
+		final boolean alice = true;
+		final TransportKeys transportKeys = createTransportKeys(999, 0);
+		final TransportKeys rotated = createTransportKeys(1000, 0);
+		final Transaction txn = new Transaction(null);
+		context.checking(new Expectations() {{
+			oneOf(crypto).deriveTransportKeys(transportId, masterKey, 999,
+					alice);
+			will(returnValue(transportKeys));
+			// Get the current time (1 ms after start of rotation period 1000)
+			oneOf(clock).currentTimeMillis();
+			will(returnValue(rotationPeriodLength * 1000 + 1));
+			// Rotate the transport keys
+			oneOf(crypto).rotateTransportKeys(transportKeys, 1000);
+			will(returnValue(rotated));
+			// Encode the tags (3 sets)
+			for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
+				exactly(3).of(crypto).encodeTag(with(any(byte[].class)),
+						with(tagKey), with(i));
+				will(new EncodeTagAction());
+			}
+			// Save the keys
+			oneOf(db).startTransaction();
+			will(returnValue(txn));
+			oneOf(db).addTransportKeys(txn, contactId, rotated);
+			oneOf(db).endTransaction(txn);
+		}});
+
+		TransportKeyManager transportKeyManager = new TransportKeyManager(db,
+				crypto, timer, clock, transportId, maxLatency);
+		// The timestamp is 1 ms before the start of rotation period 1000
+		long timestamp = rotationPeriodLength * 1000 - 1;
+		transportKeyManager.addContact(contactId, masterKey, timestamp, alice);
+
+		context.assertIsSatisfied();
+	}
+
+	@Test
+	public void testOutgoingStreamContextIsNullIfContactIsNotFound()
+			throws Exception {
+		Mockery context = new Mockery();
+		final DatabaseComponent db = context.mock(DatabaseComponent.class);
+		final CryptoComponent crypto = context.mock(CryptoComponent.class);
+		final Timer timer = context.mock(Timer.class);
+		final Clock clock = context.mock(Clock.class);
+
+		TransportKeyManager transportKeyManager = new TransportKeyManager(db,
+				crypto, timer, clock, transportId, maxLatency);
+		assertNull(transportKeyManager.getStreamContext(contactId));
+
+		context.assertIsSatisfied();
+	}
+
+	@Test
+	public void testOutgoingStreamContextIsNullIfStreamCounterIsExhausted()
+			throws Exception {
+		Mockery context = new Mockery();
+		final DatabaseComponent db = context.mock(DatabaseComponent.class);
+		final CryptoComponent crypto = context.mock(CryptoComponent.class);
+		final Timer timer = context.mock(Timer.class);
+		final Clock clock = context.mock(Clock.class);
+		final boolean alice = true;
+		// The stream counter has been exhausted
+		final TransportKeys transportKeys = createTransportKeys(1000,
+				MAX_32_BIT_UNSIGNED + 1);
+		final Transaction txn = new Transaction(null);
+		context.checking(new Expectations() {{
+			oneOf(crypto).deriveTransportKeys(transportId, masterKey, 1000,
+					alice);
+			will(returnValue(transportKeys));
+			// Get the current time (the start of rotation period 1000)
+			oneOf(clock).currentTimeMillis();
+			will(returnValue(rotationPeriodLength * 1000));
+			// Encode the tags (3 sets)
+			for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
+				exactly(3).of(crypto).encodeTag(with(any(byte[].class)),
+						with(tagKey), with(i));
+				will(new EncodeTagAction());
+			}
+			// Rotate the transport keys (the keys are unaffected)
+			oneOf(crypto).rotateTransportKeys(transportKeys, 1000);
+			will(returnValue(transportKeys));
+			// Save the keys
+			oneOf(db).startTransaction();
+			will(returnValue(txn));
+			oneOf(db).addTransportKeys(txn, contactId, transportKeys);
+			oneOf(db).endTransaction(txn);
+		}});
+
+		TransportKeyManager transportKeyManager = new TransportKeyManager(db,
+				crypto, timer, clock, transportId, maxLatency);
+		// The timestamp is at the start of rotation period 1000
+		long timestamp = rotationPeriodLength * 1000;
+		transportKeyManager.addContact(contactId, masterKey, timestamp, alice);
+		assertNull(transportKeyManager.getStreamContext(contactId));
+
+		context.assertIsSatisfied();
+	}
+
+	@Test
+	public void testOutgoingStreamCounterIsIncremented() throws Exception {
+		Mockery context = new Mockery();
+		final DatabaseComponent db = context.mock(DatabaseComponent.class);
+		final CryptoComponent crypto = context.mock(CryptoComponent.class);
+		final Timer timer = context.mock(Timer.class);
+		final Clock clock = context.mock(Clock.class);
+		final boolean alice = true;
+		// The stream counter can be used one more time before being exhausted
+		final TransportKeys transportKeys = createTransportKeys(1000,
+				MAX_32_BIT_UNSIGNED);
+		final Transaction txn = new Transaction(null);
+		final Transaction txn1 = new Transaction(null);
+		context.checking(new Expectations() {{
+			oneOf(crypto).deriveTransportKeys(transportId, masterKey, 1000,
+					alice);
+			will(returnValue(transportKeys));
+			// Get the current time (the start of rotation period 1000)
+			oneOf(clock).currentTimeMillis();
+			will(returnValue(rotationPeriodLength * 1000));
+			// Encode the tags (3 sets)
+			for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
+				exactly(3).of(crypto).encodeTag(with(any(byte[].class)),
+						with(tagKey), with(i));
+				will(new EncodeTagAction());
+			}
+			// Rotate the transport keys (the keys are unaffected)
+			oneOf(crypto).rotateTransportKeys(transportKeys, 1000);
+			will(returnValue(transportKeys));
+			// Save the keys
+			oneOf(db).startTransaction();
+			will(returnValue(txn));
+			oneOf(db).addTransportKeys(txn, contactId, transportKeys);
+			oneOf(db).endTransaction(txn);
+			// Increment the stream counter
+			oneOf(db).startTransaction();
+			will(returnValue(txn1));
+			oneOf(db).incrementStreamCounter(txn1, contactId, transportId,
+					1000);
+			oneOf(db).endTransaction(txn1);
+		}});
+
+		TransportKeyManager transportKeyManager = new TransportKeyManager(db,
+				crypto, timer, clock, transportId, maxLatency);
+		// The timestamp is at the start of rotation period 1000
+		long timestamp = rotationPeriodLength * 1000;
+		transportKeyManager.addContact(contactId, masterKey, timestamp, alice);
+		// The first request should return a stream context
+		StreamContext ctx = transportKeyManager.getStreamContext(contactId);
+		assertNotNull(ctx);
+		assertEquals(contactId, ctx.getContactId());
+		assertEquals(transportId, ctx.getTransportId());
+		assertEquals(tagKey, ctx.getTagKey());
+		assertEquals(headerKey, ctx.getHeaderKey());
+		assertEquals(MAX_32_BIT_UNSIGNED, ctx.getStreamNumber());
+		// The second request should return null, the counter is exhausted
+		assertNull(transportKeyManager.getStreamContext(contactId));
+
+		context.assertIsSatisfied();
+	}
+
 	@Test
-	public void testUnitTestsExist() {
-		fail(); // FIXME: Write tests
+	public void testIncomingStreamContextIsNullIfTagIsNotFound()
+			throws Exception {
+		Mockery context = new Mockery();
+		final DatabaseComponent db = context.mock(DatabaseComponent.class);
+		final CryptoComponent crypto = context.mock(CryptoComponent.class);
+		final Timer timer = context.mock(Timer.class);
+		final Clock clock = context.mock(Clock.class);
+		final boolean alice = true;
+		final TransportKeys transportKeys = createTransportKeys(1000, 0);
+		final Transaction txn = new Transaction(null);
+		context.checking(new Expectations() {{
+			oneOf(crypto).deriveTransportKeys(transportId, masterKey, 1000,
+					alice);
+			will(returnValue(transportKeys));
+			// Get the current time (the start of rotation period 1000)
+			oneOf(clock).currentTimeMillis();
+			will(returnValue(rotationPeriodLength * 1000));
+			// Encode the tags (3 sets)
+			for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
+				exactly(3).of(crypto).encodeTag(with(any(byte[].class)),
+						with(tagKey), with(i));
+				will(new EncodeTagAction());
+			}
+			// Rotate the transport keys (the keys are unaffected)
+			oneOf(crypto).rotateTransportKeys(transportKeys, 1000);
+			will(returnValue(transportKeys));
+			// Save the keys
+			oneOf(db).startTransaction();
+			will(returnValue(txn));
+			oneOf(db).addTransportKeys(txn, contactId, transportKeys);
+			oneOf(db).endTransaction(txn);
+		}});
+
+		TransportKeyManager transportKeyManager = new TransportKeyManager(db,
+				crypto, timer, clock, transportId, maxLatency);
+		// The timestamp is at the start of rotation period 1000
+		long timestamp = rotationPeriodLength * 1000;
+		transportKeyManager.addContact(contactId, masterKey, timestamp, alice);
+		assertNull(transportKeyManager.getStreamContext(new byte[TAG_LENGTH]));
+
+		context.assertIsSatisfied();
+	}
+
+	@Test
+	public void testTagIsNotRecognisedTwice() throws Exception {
+		Mockery context = new Mockery();
+		final DatabaseComponent db = context.mock(DatabaseComponent.class);
+		final CryptoComponent crypto = context.mock(CryptoComponent.class);
+		final Timer timer = context.mock(Timer.class);
+		final Clock clock = context.mock(Clock.class);
+		final boolean alice = true;
+		final TransportKeys transportKeys = createTransportKeys(1000, 0);
+		final Transaction txn = new Transaction(null);
+		final Transaction txn1 = new Transaction(null);
+		// Keep a copy of the tags
+		final List<byte[]> tags = new ArrayList<byte[]>();
+		context.checking(new Expectations() {{
+			oneOf(crypto).deriveTransportKeys(transportId, masterKey, 1000,
+					alice);
+			will(returnValue(transportKeys));
+			// Get the current time (the start of rotation period 1000)
+			oneOf(clock).currentTimeMillis();
+			will(returnValue(rotationPeriodLength * 1000));
+			// Encode the tags (3 sets)
+			for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
+				exactly(3).of(crypto).encodeTag(with(any(byte[].class)),
+						with(tagKey), with(i));
+				will(new EncodeTagAction(tags));
+			}
+			// Rotate the transport keys (the keys are unaffected)
+			oneOf(crypto).rotateTransportKeys(transportKeys, 1000);
+			will(returnValue(transportKeys));
+			// Save the keys
+			oneOf(db).startTransaction();
+			will(returnValue(txn));
+			oneOf(db).addTransportKeys(txn, contactId, transportKeys);
+			oneOf(db).endTransaction(txn);
+			// Encode a new tag after sliding the window
+			oneOf(crypto).encodeTag(with(any(byte[].class)),
+					with(tagKey), with((long) REORDERING_WINDOW_SIZE));
+			will(new EncodeTagAction(tags));
+			// Save the reordering window (previous rotation period, base 1)
+			oneOf(db).startTransaction();
+			will(returnValue(txn1));
+			oneOf(db).setReorderingWindow(txn1, contactId, transportId, 999,
+					1, new byte[REORDERING_WINDOW_SIZE / 8]);
+			oneOf(db).endTransaction(txn1);
+		}});
+
+		TransportKeyManager transportKeyManager = new TransportKeyManager(db,
+				crypto, timer, clock, transportId, maxLatency);
+		// The timestamp is at the start of rotation period 1000
+		long timestamp = rotationPeriodLength * 1000;
+		transportKeyManager.addContact(contactId, masterKey, timestamp, alice);
+		// Use the first tag (previous rotation period, stream number 0)
+		assertEquals(REORDERING_WINDOW_SIZE * 3, tags.size());
+		byte[] tag = tags.get(0);
+		// The first request should return a stream context
+		StreamContext ctx = transportKeyManager.getStreamContext(tag);
+		assertNotNull(ctx);
+		assertEquals(contactId, ctx.getContactId());
+		assertEquals(transportId, ctx.getTransportId());
+		assertEquals(tagKey, ctx.getTagKey());
+		assertEquals(headerKey, ctx.getHeaderKey());
+		assertEquals(0L, ctx.getStreamNumber());
+		// Another tag should have been encoded
+		assertEquals(REORDERING_WINDOW_SIZE * 3 + 1, tags.size());
+		// The second request should return null, the tag has already been used
+		assertNull(transportKeyManager.getStreamContext(tag));
+
+		context.assertIsSatisfied();
+	}
+
+	@Test
+	public void testKeysAreRotatedToCurrentPeriod() throws Exception {
+		Mockery context = new Mockery();
+		final DatabaseComponent db = context.mock(DatabaseComponent.class);
+		final CryptoComponent crypto = context.mock(CryptoComponent.class);
+		final Timer timer = context.mock(Timer.class);
+		final Clock clock = context.mock(Clock.class);
+		final Transaction txn = new Transaction(null);
+		final TransportKeys transportKeys = createTransportKeys(1000, 0);
+		final Map<ContactId, TransportKeys> loaded =
+				Collections.singletonMap(contactId, transportKeys);
+		final TransportKeys rotated = createTransportKeys(1001, 0);
+		final Transaction txn1 = new Transaction(null);
+		context.checking(new Expectations() {{
+			// Get the current time (the start of rotation period 1000)
+			oneOf(clock).currentTimeMillis();
+			will(returnValue(rotationPeriodLength * 1000));
+			// Load the transport keys
+			oneOf(db).startTransaction();
+			will(returnValue(txn));
+			oneOf(db).getTransportKeys(txn, transportId);
+			will(returnValue(loaded));
+			oneOf(db).endTransaction(txn);
+			// Rotate the transport keys (the keys are unaffected)
+			oneOf(crypto).rotateTransportKeys(transportKeys, 1000);
+			will(returnValue(transportKeys));
+			// Encode the tags (3 sets)
+			for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
+				exactly(3).of(crypto).encodeTag(with(any(byte[].class)),
+						with(tagKey), with(i));
+				will(new EncodeTagAction());
+			}
+			// Schedule key rotation at the start of the next rotation period
+			oneOf(timer).schedule(with(any(TimerTask.class)),
+					with(rotationPeriodLength));
+			will(new RunTimerTaskAction());
+			// Get the current time (the start of rotation period 1001)
+			oneOf(clock).currentTimeMillis();
+			will(returnValue(rotationPeriodLength * 1001));
+			// Rotate the transport keys
+			oneOf(crypto).rotateTransportKeys(with(any(TransportKeys.class)),
+					with(1001L));
+			will(returnValue(rotated));
+			// Encode the tags (3 sets)
+			for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
+				exactly(3).of(crypto).encodeTag(with(any(byte[].class)),
+						with(tagKey), with(i));
+				will(new EncodeTagAction());
+			}
+			// Save the keys that were rotated
+			oneOf(db).startTransaction();
+			will(returnValue(txn1));
+			oneOf(db).updateTransportKeys(txn1,
+					Collections.singletonMap(contactId, rotated));
+			oneOf(db).endTransaction(txn1);
+			// Schedule key rotation at the start of the next rotation period
+			oneOf(timer).schedule(with(any(TimerTask.class)),
+					with(rotationPeriodLength));
+		}});
+
+		TransportKeyManager transportKeyManager = new TransportKeyManager(db,
+				crypto, timer, clock, transportId, maxLatency);
+		transportKeyManager.start();
+
+		context.assertIsSatisfied();
+	}
+
+	private TransportKeys createTransportKeys(long rotationPeriod,
+			long streamCounter) {
+		IncomingKeys inPrev = new IncomingKeys(tagKey, headerKey,
+				rotationPeriod - 1);
+		IncomingKeys inCurr = new IncomingKeys(tagKey, headerKey,
+				rotationPeriod);
+		IncomingKeys inNext = new IncomingKeys(tagKey, headerKey,
+				rotationPeriod + 1);
+		OutgoingKeys outCurr = new OutgoingKeys(tagKey, headerKey,
+				rotationPeriod, streamCounter);
+		return new TransportKeys(transportId, inPrev, inCurr, inNext, outCurr);
+	}
+
+	private class EncodeTagAction implements Action {
+
+		private final Collection<byte[]> tags;
+
+		private EncodeTagAction() {
+			tags = null;
+		}
+
+		private EncodeTagAction(Collection<byte[]> tags) {
+			this.tags = tags;
+		}
+
+		@Override
+		public Object invoke(Invocation invocation) throws Throwable {
+			byte[] tag = (byte[]) invocation.getParameter(0);
+			random.nextBytes(tag);
+			if (tags != null) tags.add(tag);
+			return null;
+		}
+
+		@Override
+		public void describeTo(Description description) {
+			description.appendText("encodes a tag");
+		}
+	}
+
+	private static class RunTimerTaskAction implements Action {
+
+		@Override
+		public Object invoke(Invocation invocation) throws Throwable {
+			TimerTask task = (TimerTask) invocation.getParameter(0);
+			task.run();
+			return null;
+		}
+
+		@Override
+		public void describeTo(Description description) {
+			description.appendText("schedules a timer task");
+		}
 	}
 }