diff --git a/briar-api/src/net/sf/briar/api/transport/Endpoint.java b/briar-api/src/net/sf/briar/api/transport/Endpoint.java index 2b69df41e5a80651f8754dc5629f31fe9ae33c72..bf3dd788a93ad6a9d94a278055a9eb8f074285f5 100644 --- a/briar-api/src/net/sf/briar/api/transport/Endpoint.java +++ b/briar-api/src/net/sf/briar/api/transport/Endpoint.java @@ -5,8 +5,8 @@ import net.sf.briar.api.TransportId; public class Endpoint { - private final ContactId contactId; - private final TransportId transportId; + protected final ContactId contactId; + protected final TransportId transportId; private final long epoch; private final boolean alice; diff --git a/briar-api/src/net/sf/briar/api/transport/TemporarySecret.java b/briar-api/src/net/sf/briar/api/transport/TemporarySecret.java index 6629565b74b73c0779a04279c9b27fdc05e2208e..035ecdfae072b2424500c6f4c0c1bbf22f0ffe63 100644 --- a/briar-api/src/net/sf/briar/api/transport/TemporarySecret.java +++ b/briar-api/src/net/sf/briar/api/transport/TemporarySecret.java @@ -53,4 +53,20 @@ public class TemporarySecret extends Endpoint { public byte[] getWindowBitmap() { return bitmap; } + + @Override + public int hashCode() { + int periodHashCode = (int) (period ^ (period >>> 32)); + return contactId.hashCode() ^ transportId.hashCode() ^ periodHashCode; + } + + @Override + public boolean equals(Object o) { + if(o instanceof TemporarySecret) { + TemporarySecret s = (TemporarySecret) o; + return contactId.equals(s.contactId) && + transportId.equals(s.transportId) && period == s.period; + } + return false; + } } diff --git a/briar-core/src/net/sf/briar/db/JdbcDatabase.java b/briar-core/src/net/sf/briar/db/JdbcDatabase.java index 1aa5f77e7a41d03880c1bf495f3252d4fcb09dc4..0a443ee7acc6977c03b0a82e8e5cfa11e19c3b16 100644 --- a/briar-core/src/net/sf/briar/db/JdbcDatabase.java +++ b/briar-core/src/net/sf/briar/db/JdbcDatabase.java @@ -848,7 +848,7 @@ abstract class JdbcDatabase implements Database<Connection> { for(TemporarySecret s : secrets) { ps.setInt(1, s.getContactId().getInt()); ps.setBytes(2, s.getTransportId().getBytes()); - ps.setLong(3, s.getPeriod() - 1); + ps.setLong(3, s.getPeriod() - 2); ps.addBatch(); } batchAffected = ps.executeBatch(); diff --git a/briar-core/src/net/sf/briar/transport/KeyManagerImpl.java b/briar-core/src/net/sf/briar/transport/KeyManagerImpl.java index d72f05d1650fe43e4f014901a3457b72e9b2b147..b1b1355a0dd68944036f87368c48127f80514b5b 100644 --- a/briar-core/src/net/sf/briar/transport/KeyManagerImpl.java +++ b/briar-core/src/net/sf/briar/transport/KeyManagerImpl.java @@ -49,9 +49,9 @@ class KeyManagerImpl extends TimerTask implements KeyManager, DatabaseListener { // All of the following are locking: this private final Map<TransportId, Long> maxLatencies; - private final Map<EndpointKey, TemporarySecret> outgoing; - private final Map<EndpointKey, TemporarySecret> incomingOld; - private final Map<EndpointKey, TemporarySecret> incomingNew; + private final Map<EndpointKey, TemporarySecret> oldSecrets; + private final Map<EndpointKey, TemporarySecret> currentSecrets; + private final Map<EndpointKey, TemporarySecret> newSecrets; @Inject KeyManagerImpl(CryptoComponent crypto, DatabaseComponent db, @@ -63,9 +63,9 @@ class KeyManagerImpl extends TimerTask implements KeyManager, DatabaseListener { this.clock = clock; this.timer = timer; maxLatencies = new HashMap<TransportId, Long>(); - outgoing = new HashMap<EndpointKey, TemporarySecret>(); - incomingOld = new HashMap<EndpointKey, TemporarySecret>(); - incomingNew = new HashMap<EndpointKey, TemporarySecret>(); + oldSecrets = new HashMap<EndpointKey, TemporarySecret>(); + currentSecrets = new HashMap<EndpointKey, TemporarySecret>(); + newSecrets = new HashMap<EndpointKey, TemporarySecret>(); } public synchronized boolean start() { @@ -85,7 +85,7 @@ class KeyManagerImpl extends TimerTask implements KeyManager, DatabaseListener { // Replace any dead secrets Collection<TemporarySecret> created = replaceDeadSecrets(now, dead); if(!created.isEmpty()) { - // Store any secrets that have been created + // Store any secrets that have been created, removing any dead ones try { db.addSecrets(created); } catch(DbException e) { @@ -93,10 +93,12 @@ class KeyManagerImpl extends TimerTask implements KeyManager, DatabaseListener { return false; } } - // Pass the current incoming secrets to the recogniser - for(TemporarySecret s : incomingOld.values()) + // Pass the old, current and new secrets to the recogniser + for(TemporarySecret s : oldSecrets.values()) + connectionRecogniser.addSecret(s); + for(TemporarySecret s : currentSecrets.values()) connectionRecogniser.addSecret(s); - for(TemporarySecret s : incomingNew.values()) + for(TemporarySecret s : newSecrets.values()) connectionRecogniser.addSecret(s); // Schedule periodic key rotation timer.scheduleAtFixedRate(this, MS_BETWEEN_CHECKS, MS_BETWEEN_CHECKS); @@ -110,29 +112,24 @@ class KeyManagerImpl extends TimerTask implements KeyManager, DatabaseListener { Collection<TemporarySecret> dead = new ArrayList<TemporarySecret>(); for(TemporarySecret s : secrets) { // Discard the secret if the transport has been removed - if(!maxLatencies.containsKey(s.getTransportId())) { + Long maxLatency = maxLatencies.get(s.getTransportId()); + if(maxLatency == null) { ByteUtils.erase(s.getSecret()); continue; } - EndpointKey k = new EndpointKey(s); - long rotationPeriod = getRotationPeriod(s); - long creationTime = getCreationTime(s); - long activationTime = creationTime + MAX_CLOCK_DIFFERENCE; - long successorCreationTime = creationTime + rotationPeriod; - long deactivationTime = activationTime + rotationPeriod; - long destructionTime = successorCreationTime + rotationPeriod; + long rotation = maxLatency + MAX_CLOCK_DIFFERENCE; + long creationTime = s.getEpoch() + rotation * (s.getPeriod() - 2); + long activationTime = creationTime + rotation; + long deactivationTime = activationTime + rotation; + long destructionTime = deactivationTime + rotation; if(now >= destructionTime) { dead.add(s); } else if(now >= deactivationTime) { - incomingOld.put(k, s); - } else if(now >= successorCreationTime) { - incomingOld.put(k, s); - outgoing.put(k, s); + oldSecrets.put(new EndpointKey(s), s); } else if(now >= activationTime) { - incomingNew.put(k, s); - outgoing.put(k, s); + currentSecrets.put(new EndpointKey(s), s); } else if(now >= creationTime) { - incomingNew.put(k, s); + newSecrets.put(new EndpointKey(s), s); } else { // FIXME: Work out what to do here throw new Error("Clock has moved backwards"); @@ -147,62 +144,50 @@ class KeyManagerImpl extends TimerTask implements KeyManager, DatabaseListener { Collection<TemporarySecret> dead) { Collection<TemporarySecret> created = new ArrayList<TemporarySecret>(); for(TemporarySecret s : dead) { + Long maxLatency = maxLatencies.get(s.getTransportId()); + if(maxLatency == null) throw new IllegalStateException(); // Work out which rotation period we're in - long rotationPeriod = getRotationPeriod(s); long elapsed = now - s.getEpoch(); - long period = (elapsed / rotationPeriod) + 1; - if(period <= s.getPeriod()) throw new IllegalStateException(); - // Derive the two current incoming secrets - byte[] secret1 = s.getSecret(); - for(long p = s.getPeriod(); p < period; p++) { - byte[] temp = crypto.deriveNextSecret(secret1, p); - ByteUtils.erase(secret1); - secret1 = temp; + long rotation = maxLatency + MAX_CLOCK_DIFFERENCE; + long currentPeriod = (elapsed / rotation) + 1; + if(currentPeriod < 1) throw new IllegalStateException(); + if(currentPeriod - s.getPeriod() < 2) + throw new IllegalStateException(); + // Derive the old, current and new secrets + byte[] b1 = s.getSecret(); + for(long p = s.getPeriod() + 1; p < currentPeriod; p++) { + byte[] temp = crypto.deriveNextSecret(b1, p); + ByteUtils.erase(b1); + b1 = temp; } - byte[] secret2 = crypto.deriveNextSecret(secret1, period); - // Add the incoming secrets to their respective maps - the older - // may already exist if the dead secret has a living successor + byte[] b2 = crypto.deriveNextSecret(b1, currentPeriod); + byte[] b3 = crypto.deriveNextSecret(b2, currentPeriod + 1); + TemporarySecret s1 = new TemporarySecret(s, currentPeriod - 1, b1); + TemporarySecret s2 = new TemporarySecret(s, currentPeriod, b2); + TemporarySecret s3 = new TemporarySecret(s, currentPeriod + 1, b3); + // Add the secrets to their respective maps - the old and current + // secrets may already exist, in which case erase the duplicates EndpointKey k = new EndpointKey(s); - TemporarySecret s1 = new TemporarySecret(s, period - 1, secret1); - created.add(s1); - TemporarySecret exists = incomingOld.put(k, s1); - if(exists != null) ByteUtils.erase(exists.getSecret()); - TemporarySecret s2 = new TemporarySecret(s, period, secret2); - created.add(s2); - incomingNew.put(k, s2); - // One of the incoming secrets is the current outgoing secret - if(elapsed % rotationPeriod < MAX_CLOCK_DIFFERENCE) { - // The outgoing secret is the older incoming secret - outgoing.put(k, s1); - } else { - // The outgoing secret is the newer incoming secret - outgoing.put(k, s2); - } + TemporarySecret exists = oldSecrets.put(k, s1); + if(exists == null) created.add(s1); + else ByteUtils.erase(exists.getSecret()); + exists = currentSecrets.put(k, s2); + if(exists == null) created.add(s2); + else ByteUtils.erase(exists.getSecret()); + newSecrets.put(k, s3); + created.add(s3); } return created; } - // Locking: this - private long getRotationPeriod(Endpoint ep) { - Long maxLatency = maxLatencies.get(ep.getTransportId()); - if(maxLatency == null) throw new IllegalStateException(); - return 2 * MAX_CLOCK_DIFFERENCE + maxLatency; - } - - // Locking: this - private long getCreationTime(TemporarySecret s) { - long rotationPeriod = getRotationPeriod(s); - return s.getEpoch() + rotationPeriod * s.getPeriod(); - } - public synchronized void stop() { db.removeListener(this); timer.cancel(); connectionRecogniser.removeSecrets(); maxLatencies.clear(); - removeAndEraseSecrets(outgoing); - removeAndEraseSecrets(incomingOld); - removeAndEraseSecrets(incomingNew); + removeAndEraseSecrets(oldSecrets); + removeAndEraseSecrets(currentSecrets); + removeAndEraseSecrets(newSecrets); } // Locking: this @@ -213,7 +198,7 @@ class KeyManagerImpl extends TimerTask implements KeyManager, DatabaseListener { public synchronized ConnectionContext getConnectionContext(ContactId c, TransportId t) { - TemporarySecret s = outgoing.get(new EndpointKey(c, t)); + TemporarySecret s = currentSecrets.get(new EndpointKey(c, t)); if(s == null) return null; long connection; try { @@ -227,41 +212,36 @@ class KeyManagerImpl extends TimerTask implements KeyManager, DatabaseListener { } public synchronized void endpointAdded(Endpoint ep, byte[] initialSecret) { - if(!maxLatencies.containsKey(ep.getTransportId())) { + Long maxLatency = maxLatencies.get(ep.getTransportId()); + if(maxLatency == null) { if(LOG.isLoggable(WARNING)) LOG.warning("No such transport"); return; } // Work out which rotation period we're in - long now = clock.currentTimeMillis(); - long rotationPeriod = getRotationPeriod(ep); - long elapsed = now - ep.getEpoch(); - long period = (elapsed / rotationPeriod) + 1; - if(period < 1) throw new IllegalStateException(); - // Derive the two current incoming secrets - byte[] secret1 = initialSecret; - for(long p = 0; p < period; p++) { - byte[] temp = crypto.deriveNextSecret(secret1, p); - ByteUtils.erase(secret1); - secret1 = temp; + long elapsed = clock.currentTimeMillis() - ep.getEpoch(); + long rotation = maxLatency + MAX_CLOCK_DIFFERENCE; + long currentPeriod = (elapsed / rotation) + 1; + if(currentPeriod < 1) throw new IllegalStateException(); + // Derive the old, current and new secrets + byte[] b1 = initialSecret; + for(long p = 0; p < currentPeriod; p++) { + byte[] temp = crypto.deriveNextSecret(b1, p); + ByteUtils.erase(b1); + b1 = temp; } - byte[] secret2 = crypto.deriveNextSecret(secret1, period); + byte[] b2 = crypto.deriveNextSecret(b1, currentPeriod); + byte[] b3 = crypto.deriveNextSecret(b2, currentPeriod + 1); + TemporarySecret s1 = new TemporarySecret(ep, currentPeriod - 1, b1); + TemporarySecret s2 = new TemporarySecret(ep, currentPeriod, b2); + TemporarySecret s3 = new TemporarySecret(ep, currentPeriod + 1, b3); // Add the incoming secrets to their respective maps EndpointKey k = new EndpointKey(ep); - TemporarySecret s1 = new TemporarySecret(ep, period - 1, secret1); - incomingOld.put(k, s1); - TemporarySecret s2 = new TemporarySecret(ep, period, secret2); - incomingNew.put(k, s2); - // One of the incoming secrets is the current outgoing secret - if(elapsed % rotationPeriod < MAX_CLOCK_DIFFERENCE) { - // The outgoing secret is the older incoming secret - outgoing.put(k, s1); - } else { - // The outgoing secret is the newer incoming secret - outgoing.put(k, s2); - } + oldSecrets.put(k, s1); + currentSecrets.put(k, s2); + newSecrets.put(k, s3); // Store the new secrets try { - db.addSecrets(Arrays.asList(s1, s2)); + db.addSecrets(Arrays.asList(s1, s2, s3)); } catch(DbException e) { if(LOG.isLoggable(WARNING)) LOG.log(WARNING, e.toString(), e); return; @@ -269,17 +249,18 @@ class KeyManagerImpl extends TimerTask implements KeyManager, DatabaseListener { // Pass the new secrets to the recogniser connectionRecogniser.addSecret(s1); connectionRecogniser.addSecret(s2); + connectionRecogniser.addSecret(s3); } - @Override public synchronized void run() { // Rebuild the maps because we may be running a whole period late Collection<TemporarySecret> secrets = new ArrayList<TemporarySecret>(); - secrets.addAll(incomingOld.values()); - secrets.addAll(incomingNew.values()); - outgoing.clear(); - incomingOld.clear(); - incomingNew.clear(); + secrets.addAll(oldSecrets.values()); + secrets.addAll(currentSecrets.values()); + secrets.addAll(newSecrets.values()); + oldSecrets.clear(); + currentSecrets.clear(); + newSecrets.clear(); // Work out what phase of its lifecycle each secret is in long now = clock.currentTimeMillis(); Collection<TemporarySecret> dead = assignSecretsToMaps(now, secrets); @@ -309,9 +290,9 @@ class KeyManagerImpl extends TimerTask implements KeyManager, DatabaseListener { ContactId c = ((ContactRemovedEvent) e).getContactId(); connectionRecogniser.removeSecrets(c); synchronized(this) { - removeAndEraseSecrets(c, outgoing); - removeAndEraseSecrets(c, incomingOld); - removeAndEraseSecrets(c, incomingNew); + removeAndEraseSecrets(c, oldSecrets); + removeAndEraseSecrets(c, currentSecrets); + removeAndEraseSecrets(c, newSecrets); } } else if(e instanceof TransportAddedEvent) { TransportAddedEvent t = (TransportAddedEvent) e; @@ -323,9 +304,9 @@ class KeyManagerImpl extends TimerTask implements KeyManager, DatabaseListener { connectionRecogniser.removeSecrets(t); synchronized(this) { maxLatencies.remove(t); - removeAndEraseSecrets(t, outgoing); - removeAndEraseSecrets(t, incomingOld); - removeAndEraseSecrets(t, incomingNew); + removeAndEraseSecrets(t, oldSecrets); + removeAndEraseSecrets(t, currentSecrets); + removeAndEraseSecrets(t, newSecrets); } } } diff --git a/briar-tests/build.xml b/briar-tests/build.xml index 455a843144d9aa867c088a4703f0cedfab26b3a0..8ed3ed747d8baf9b1f12208e5c371bb59b12d5f2 100644 --- a/briar-tests/build.xml +++ b/briar-tests/build.xml @@ -103,6 +103,7 @@ <test name='net.sf.briar.transport.ConnectionWindowTest'/> <test name='net.sf.briar.transport.ConnectionWriterImplTest'/> <test name='net.sf.briar.transport.IncomingEncryptionLayerTest'/> + <test name='net.sf.briar.transport.KeyManagerImplTest'/> <test name='net.sf.briar.transport.OutgoingEncryptionLayerTest'/> <test name='net.sf.briar.transport.TransportIntegrationTest'/> <test name='net.sf.briar.transport.TransportConnectionRecogniserTest'/> diff --git a/briar-tests/src/net/sf/briar/BriarTestCase.java b/briar-tests/src/net/sf/briar/BriarTestCase.java index 32f496aef4ab891bc4d8de2f19b1deb320b36d8f..171b620cde70768e3cfd44f57d30311acddeefd7 100644 --- a/briar-tests/src/net/sf/briar/BriarTestCase.java +++ b/briar-tests/src/net/sf/briar/BriarTestCase.java @@ -1,6 +1,5 @@ package net.sf.briar; - import java.lang.Thread.UncaughtExceptionHandler; import junit.framework.TestCase; diff --git a/briar-tests/src/net/sf/briar/db/H2DatabaseTest.java b/briar-tests/src/net/sf/briar/db/H2DatabaseTest.java index c392f385e445c52b3df0ce7fea4ecae042ed9f76..a42257acd4693a299518504d885a4ae46ae7a54b 100644 --- a/briar-tests/src/net/sf/briar/db/H2DatabaseTest.java +++ b/briar-tests/src/net/sf/briar/db/H2DatabaseTest.java @@ -1511,12 +1511,13 @@ public class H2DatabaseTest extends BriarTestCase { @Test public void testTemporarySecrets() throws Exception { - // Create an endpoint and three consecutive temporary secrets + // Create an endpoint and four consecutive temporary secrets long epoch = 123, latency = 234; boolean alice = false; long outgoing1 = 345, centre1 = 456; long outgoing2 = 567, centre2 = 678; long outgoing3 = 789, centre3 = 890; + long outgoing4 = 901, centre4 = 123; Endpoint ep = new Endpoint(contactId, transportId, epoch, alice); Random random = new Random(); byte[] secret1 = new byte[32], bitmap1 = new byte[4]; @@ -1534,6 +1535,11 @@ public class H2DatabaseTest extends BriarTestCase { random.nextBytes(bitmap3); TemporarySecret s3 = new TemporarySecret(contactId, transportId, epoch, alice, 2, secret3, outgoing3, centre3, bitmap3); + byte[] secret4 = new byte[32], bitmap4 = new byte[4]; + random.nextBytes(secret4); + random.nextBytes(bitmap4); + TemporarySecret s4 = new TemporarySecret(contactId, transportId, epoch, + alice, 3, secret4, outgoing4, centre4, bitmap4); Database<Connection> db = open(false); Connection txn = db.startTransaction(); @@ -1541,18 +1547,18 @@ public class H2DatabaseTest extends BriarTestCase { // Initially there should be no secrets in the database assertEquals(Collections.emptyList(), db.getSecrets(txn)); - // Add the contact, the transport, the endpoint and the first two - // secrets (periods 0 and 1) + // Add the contact, the transport, the endpoint and the first three + // secrets (periods 0, 1 and 2) db.addLocalAuthor(txn, localAuthor); assertEquals(contactId, db.addContact(txn, author, localAuthorId)); db.addTransport(txn, transportId, latency); db.addEndpoint(txn, ep); - db.addSecrets(txn, Arrays.asList(s1, s2)); + db.addSecrets(txn, Arrays.asList(s1, s2, s3)); - // Retrieve the first two secrets + // Retrieve the first three secrets Collection<TemporarySecret> secrets = db.getSecrets(txn); - assertEquals(2, secrets.size()); - boolean foundFirst = false, foundSecond = false; + assertEquals(3, secrets.size()); + boolean foundFirst = false, foundSecond = false, foundThird = false; for(TemporarySecret s : secrets) { assertEquals(contactId, s.getContactId()); assertEquals(transportId, s.getTransportId()); @@ -1570,19 +1576,26 @@ public class H2DatabaseTest extends BriarTestCase { assertEquals(centre2, s.getWindowCentre()); assertArrayEquals(bitmap2, s.getWindowBitmap()); foundSecond = true; + } else if(s.getPeriod() == 2) { + assertArrayEquals(secret3, s.getSecret()); + assertEquals(outgoing3, s.getOutgoingConnectionCounter()); + assertEquals(centre3, s.getWindowCentre()); + assertArrayEquals(bitmap3, s.getWindowBitmap()); + foundThird = true; } else { fail(); } } assertTrue(foundFirst); assertTrue(foundSecond); + assertTrue(foundThird); - // Adding the third secret (period 2) should delete the first (period 0) - db.addSecrets(txn, Arrays.asList(s3)); + // Adding the fourth secret (period 3) should delete the first + db.addSecrets(txn, Arrays.asList(s4)); secrets = db.getSecrets(txn); - assertEquals(2, secrets.size()); - foundSecond = false; - boolean foundThird = false; + assertEquals(3, secrets.size()); + foundSecond = foundThird = false; + boolean foundFourth = false; for(TemporarySecret s : secrets) { assertEquals(contactId, s.getContactId()); assertEquals(transportId, s.getTransportId()); @@ -1600,12 +1613,19 @@ public class H2DatabaseTest extends BriarTestCase { assertEquals(centre3, s.getWindowCentre()); assertArrayEquals(bitmap3, s.getWindowBitmap()); foundThird = true; + } else if(s.getPeriod() == 3) { + assertArrayEquals(secret4, s.getSecret()); + assertEquals(outgoing4, s.getOutgoingConnectionCounter()); + assertEquals(centre4, s.getWindowCentre()); + assertArrayEquals(bitmap4, s.getWindowBitmap()); + foundFourth = true; } else { fail(); } } assertTrue(foundSecond); assertTrue(foundThird); + assertTrue(foundFourth); // Removing the contact should remove the secrets db.removeContact(txn, contactId); diff --git a/briar-tests/src/net/sf/briar/transport/KeyManagerImplTest.java b/briar-tests/src/net/sf/briar/transport/KeyManagerImplTest.java new file mode 100644 index 0000000000000000000000000000000000000000..631134b6eaffe263ac5b50a545814bdbbda0e120 --- /dev/null +++ b/briar-tests/src/net/sf/briar/transport/KeyManagerImplTest.java @@ -0,0 +1,193 @@ +package net.sf.briar.transport; + +import static net.sf.briar.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE; +import static org.junit.Assert.assertArrayEquals; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Random; +import java.util.TimerTask; + +import net.sf.briar.BriarTestCase; +import net.sf.briar.TestUtils; +import net.sf.briar.api.ContactId; +import net.sf.briar.api.TransportId; +import net.sf.briar.api.clock.Clock; +import net.sf.briar.api.clock.Timer; +import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.api.db.DatabaseComponent; +import net.sf.briar.api.db.event.DatabaseListener; +import net.sf.briar.api.transport.ConnectionRecogniser; +import net.sf.briar.api.transport.Endpoint; +import net.sf.briar.api.transport.TemporarySecret; + +import org.jmock.Expectations; +import org.jmock.Mockery; +import org.junit.Before; +import org.junit.Test; + +public class KeyManagerImplTest extends BriarTestCase { + + private final Random random = new Random(); + private final ContactId contactId; + private final TransportId transportId; + private final long maxLatency; + private final long rotationPeriodLength; + private final byte[] secret0, secret1, secret2, secret3; + private final long epoch = 1000L * 1000L * 1000L * 1000L; + + public KeyManagerImplTest() { + contactId = new ContactId(234); + transportId = new TransportId(TestUtils.getRandomId()); + maxLatency = 2 * 60 * 1000; // 2 minutes + rotationPeriodLength = maxLatency + MAX_CLOCK_DIFFERENCE; + secret0 = new byte[32]; + secret1 = new byte[32]; + secret2 = new byte[32]; + secret3 = new byte[32]; + } + + @Before + public void setUp() { + random.nextBytes(secret0); + random.nextBytes(secret1); + random.nextBytes(secret2); + random.nextBytes(secret3); + } + + @Test + public void testStartAndStop() throws Exception { + Mockery context = new Mockery(); + final CryptoComponent crypto = context.mock(CryptoComponent.class); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final ConnectionRecogniser connectionRecogniser = + context.mock(ConnectionRecogniser.class); + final Clock clock = context.mock(Clock.class); + final Timer timer = context.mock(Timer.class); + final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, + connectionRecogniser, clock, timer); + context.checking(new Expectations() {{ + // start() + oneOf(db).addListener(with(any(DatabaseListener.class))); + oneOf(db).getSecrets(); + will(returnValue(Collections.emptyList())); + oneOf(db).getTransportLatencies(); + will(returnValue(Collections.emptyMap())); + oneOf(clock).currentTimeMillis(); + will(returnValue(epoch)); + oneOf(timer).scheduleAtFixedRate(with(any(TimerTask.class)), + with(any(long.class)), with(any(long.class))); + // stop() + oneOf(db).removeListener(with(any(DatabaseListener.class))); + oneOf(timer).cancel(); + oneOf(connectionRecogniser).removeSecrets(); + }}); + + assertTrue(keyManager.start()); + keyManager.stop(); + + context.assertIsSatisfied(); + } + + @Test + public void testLoadSecretsAtEpoch() throws Exception { + Mockery context = new Mockery(); + final CryptoComponent crypto = context.mock(CryptoComponent.class); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final ConnectionRecogniser connectionRecogniser = + context.mock(ConnectionRecogniser.class); + final Clock clock = context.mock(Clock.class); + final Timer timer = context.mock(Timer.class); + final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, + connectionRecogniser, clock, timer); + // The DB contains secrets for periods 0 - 2 + Endpoint ep = new Endpoint(contactId, transportId, epoch, true); + final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0); + final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1); + final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2); + context.checking(new Expectations() {{ + // start() + oneOf(db).addListener(with(any(DatabaseListener.class))); + oneOf(db).getSecrets(); + will(returnValue(Arrays.asList(s0, s1, s2))); + oneOf(db).getTransportLatencies(); + will(returnValue(Collections.singletonMap(transportId, + maxLatency))); + // The current time is the second secret's activation time + oneOf(clock).currentTimeMillis(); + will(returnValue(epoch)); + // The secrets for periods 0 - 2 should be added to the recogniser + oneOf(connectionRecogniser).addSecret(s0); + oneOf(connectionRecogniser).addSecret(s1); + oneOf(connectionRecogniser).addSecret(s2); + oneOf(timer).scheduleAtFixedRate(with(any(TimerTask.class)), + with(any(long.class)), with(any(long.class))); + // stop() + oneOf(db).removeListener(with(any(DatabaseListener.class))); + oneOf(timer).cancel(); + oneOf(connectionRecogniser).removeSecrets(); + }}); + + assertTrue(keyManager.start()); + keyManager.stop(); + + context.assertIsSatisfied(); + } + + @Test + public void testLoadSecretsAtNewActivationTime() throws Exception { + Mockery context = new Mockery(); + final CryptoComponent crypto = context.mock(CryptoComponent.class); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final ConnectionRecogniser connectionRecogniser = + context.mock(ConnectionRecogniser.class); + final Clock clock = context.mock(Clock.class); + final Timer timer = context.mock(Timer.class); + final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, + connectionRecogniser, clock, timer); + // The DB contains secrets for periods 0 - 2 + Endpoint ep = new Endpoint(contactId, transportId, epoch, true); + final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0); + final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1); + final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2); + // A fourth secret should be derived and stored + final TemporarySecret s3 = new TemporarySecret(ep, 3, secret3); + context.checking(new Expectations() {{ + // start() + oneOf(db).addListener(with(any(DatabaseListener.class))); + oneOf(db).getSecrets(); + will(returnValue(Arrays.asList(s0, s1, s2))); + oneOf(db).getTransportLatencies(); + will(returnValue(Collections.singletonMap(transportId, + maxLatency))); + // The current time is the third secret's activation time + oneOf(clock).currentTimeMillis(); + will(returnValue(epoch + rotationPeriodLength)); + // A fourth secret should be derived and stored + oneOf(crypto).deriveNextSecret(secret0, 1); + will(returnValue(secret1.clone())); + oneOf(crypto).deriveNextSecret(secret1, 2); + will(returnValue(secret2.clone())); + oneOf(crypto).deriveNextSecret(secret2, 3); + will(returnValue(secret3)); + oneOf(db).addSecrets(Arrays.asList(s3)); + // The secrets for periods 1 - 3 should be added to the recogniser + oneOf(connectionRecogniser).addSecret(s1); + oneOf(connectionRecogniser).addSecret(s2); + oneOf(connectionRecogniser).addSecret(s3); + oneOf(timer).scheduleAtFixedRate(with(any(TimerTask.class)), + with(any(long.class)), with(any(long.class))); + // stop() + oneOf(db).removeListener(with(any(DatabaseListener.class))); + oneOf(timer).cancel(); + oneOf(connectionRecogniser).removeSecrets(); + }}); + + assertTrue(keyManager.start()); + // The dead secret should have been erased + assertArrayEquals(new byte[32], secret0); + keyManager.stop(); + + context.assertIsSatisfied(); + } +}