diff --git a/test/build.xml b/test/build.xml index 2791cc2a029c6e1bb278c112185650a450cd6bae..d008290f4224b3453dee129229e291d929aa16a2 100644 --- a/test/build.xml +++ b/test/build.xml @@ -32,6 +32,7 @@ <test name='net.sf.briar.serial.ReaderImplTest'/> <test name='net.sf.briar.serial.WriterImplTest'/> <test name='net.sf.briar.setup.SetupWorkerTest'/> + <test name='net.sf.briar.transport.ConnectionRecogniserImplTest'/> <test name='net.sf.briar.transport.ConnectionWindowImplTest'/> <test name='net.sf.briar.transport.PacketEncrypterImplTest'/> <test name='net.sf.briar.transport.PacketWriterImplTest'/> diff --git a/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java b/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java new file mode 100644 index 0000000000000000000000000000000000000000..7ec47deb53e50043c37a41a811f03538b7ef14d5 --- /dev/null +++ b/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java @@ -0,0 +1,98 @@ +package net.sf.briar.transport; + +import java.util.Collection; +import java.util.Collections; + +import javax.crypto.Cipher; +import javax.crypto.SecretKey; + +import junit.framework.TestCase; +import net.sf.briar.api.ContactId; +import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.api.db.DatabaseComponent; +import net.sf.briar.api.transport.ConnectionWindow; +import net.sf.briar.crypto.CryptoModule; + +import org.jmock.Expectations; +import org.jmock.Mockery; +import org.junit.Test; + +import com.google.inject.Guice; +import com.google.inject.Injector; + +public class ConnectionRecogniserImplTest extends TestCase { + + private final CryptoComponent crypto; + private final ContactId contactId; + private final byte[] secret; + private final int transportId; + private final ConnectionWindow connectionWindow; + + public ConnectionRecogniserImplTest() { + super(); + Injector i = Guice.createInjector(new CryptoModule()); + crypto = i.getInstance(CryptoComponent.class); + contactId = new ContactId(1); + secret = new byte[18]; + transportId = 123; + connectionWindow = new ConnectionWindowImpl(0L, 0); + } + + @Test + public void testUnexpectedTag() throws Exception { + Mockery context = new Mockery(); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + context.checking(new Expectations() {{ + oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class))); + oneOf(db).getContacts(); + will(returnValue(Collections.singleton(contactId))); + oneOf(db).getSharedSecret(contactId); + will(returnValue(secret)); + oneOf(db).getConnectionWindow(contactId, transportId); + will(returnValue(connectionWindow)); + }}); + final ConnectionRecogniserImpl c = + new ConnectionRecogniserImpl(transportId, crypto, db); + assertNull(c.acceptConnection(new byte[16])); + context.assertIsSatisfied(); + } + + @Test + public void testExpectedTag() throws Exception { + // Calculate the expected tag for connection number 3 + SecretKey tagKey = crypto.deriveTagKey(secret); + Cipher tagCipher = crypto.getTagCipher(); + tagCipher.init(Cipher.ENCRYPT_MODE, tagKey); + byte[] tag = TagEncoder.encodeTag(transportId, 3L, 0); + byte[] encryptedTag = tagCipher.doFinal(tag); + + Mockery context = new Mockery(); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + context.checking(new Expectations() {{ + oneOf(db).addListener(with(any(ConnectionRecogniserImpl.class))); + oneOf(db).getContacts(); + will(returnValue(Collections.singleton(contactId))); + oneOf(db).getSharedSecret(contactId); + will(returnValue(secret)); + oneOf(db).getConnectionWindow(contactId, transportId); + will(returnValue(connectionWindow)); + oneOf(db).setConnectionWindow(contactId, transportId, + connectionWindow); + }}); + final ConnectionRecogniserImpl c = + new ConnectionRecogniserImpl(transportId, crypto, db); + // First time - the tag should be expected + assertEquals(contactId, c.acceptConnection(encryptedTag)); + // Second time - the tag should no longer be expected + assertNull(c.acceptConnection(encryptedTag)); + // The window should have advanced + assertEquals(4L, connectionWindow.getCentre()); + Collection<Long> unseen = connectionWindow.getUnseenConnectionNumbers(); + assertEquals(19, unseen.size()); + for(int i = 0; i < 19; i++) { + if(i == 3) continue; + assertTrue(unseen.contains(Long.valueOf(i))); + } + context.assertIsSatisfied(); + } +}