diff --git a/components/net/sf/briar/protocol/AckReader.java b/components/net/sf/briar/protocol/AckReader.java index 977eaa8dabdc855fcd5e2471d90db2f89db80810..0e89c6bcb1d1588b0ddd19fb1f1edd3d5e9f3473 100644 --- a/components/net/sf/briar/protocol/AckReader.java +++ b/components/net/sf/briar/protocol/AckReader.java @@ -1,5 +1,7 @@ package net.sf.briar.protocol; +import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH; + import java.io.IOException; import java.util.ArrayList; import java.util.Collections; @@ -10,7 +12,6 @@ import net.sf.briar.api.FormatException; import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.PacketFactory; -import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.serial.Consumer; @@ -28,8 +29,7 @@ class AckReader implements ObjectReader<Ack> { public Ack readObject(Reader r) throws IOException { // Initialise the consumer - Consumer counting = - new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH); + Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH); // Read the data r.addConsumer(counting); r.readStructId(Types.ACK); diff --git a/components/net/sf/briar/protocol/BatchReader.java b/components/net/sf/briar/protocol/BatchReader.java index 4864d6121a7dfb96fd88d17b12e3f28baf85e7f1..55031796a11fa1c1e225b4eb353372106364a9ee 100644 --- a/components/net/sf/briar/protocol/BatchReader.java +++ b/components/net/sf/briar/protocol/BatchReader.java @@ -1,52 +1,41 @@ package net.sf.briar.protocol; +import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH; + import java.io.IOException; import java.util.List; import net.sf.briar.api.FormatException; -import net.sf.briar.api.crypto.CryptoComponent; -import net.sf.briar.api.crypto.MessageDigest; -import net.sf.briar.api.protocol.BatchId; -import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.UnverifiedBatch; import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.CountingConsumer; -import net.sf.briar.api.serial.DigestingConsumer; import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.Reader; class BatchReader implements ObjectReader<UnverifiedBatch> { - private final MessageDigest messageDigest; private final ObjectReader<UnverifiedMessage> messageReader; private final UnverifiedBatchFactory batchFactory; - BatchReader(CryptoComponent crypto, - ObjectReader<UnverifiedMessage> messageReader, + BatchReader(ObjectReader<UnverifiedMessage> messageReader, UnverifiedBatchFactory batchFactory) { - messageDigest = crypto.getMessageDigest(); this.messageReader = messageReader; this.batchFactory = batchFactory; } public UnverifiedBatch readObject(Reader r) throws IOException { - // Initialise the consumers - Consumer counting = - new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH); - DigestingConsumer digesting = new DigestingConsumer(messageDigest); - // Read and digest the data + // Initialise the consumer + Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH); + // Read the data r.addConsumer(counting); - r.addConsumer(digesting); r.readStructId(Types.BATCH); r.addObjectReader(Types.MESSAGE, messageReader); List<UnverifiedMessage> messages = r.readList(UnverifiedMessage.class); r.removeObjectReader(Types.MESSAGE); - r.removeConsumer(digesting); r.removeConsumer(counting); if(messages.isEmpty()) throw new FormatException(); // Build and return the batch - BatchId id = new BatchId(messageDigest.digest()); - return batchFactory.createUnverifiedBatch(id, messages); + return batchFactory.createUnverifiedBatch( messages); } } diff --git a/components/net/sf/briar/protocol/MessageReader.java b/components/net/sf/briar/protocol/MessageReader.java index 4c7b8a1a6da276e5cc8e521c30ad243c924aa2df..7a4ddc22f9ff5bef6c1b38405393dad77d35fe36 100644 --- a/components/net/sf/briar/protocol/MessageReader.java +++ b/components/net/sf/briar/protocol/MessageReader.java @@ -1,12 +1,17 @@ package net.sf.briar.protocol; +import static net.sf.briar.api.protocol.ProtocolConstants.MAX_BODY_LENGTH; +import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH; +import static net.sf.briar.api.protocol.ProtocolConstants.MAX_SIGNATURE_LENGTH; +import static net.sf.briar.api.protocol.ProtocolConstants.MAX_SUBJECT_LENGTH; +import static net.sf.briar.api.protocol.ProtocolConstants.SALT_LENGTH; + import java.io.IOException; import net.sf.briar.api.FormatException; import net.sf.briar.api.protocol.Author; import net.sf.briar.api.protocol.Group; import net.sf.briar.api.protocol.MessageId; -import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.serial.CopyingConsumer; @@ -27,8 +32,7 @@ class MessageReader implements ObjectReader<UnverifiedMessage> { public UnverifiedMessage readObject(Reader r) throws IOException { CopyingConsumer copying = new CopyingConsumer(); - CountingConsumer counting = - new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH); + CountingConsumer counting = new CountingConsumer(MAX_PACKET_LENGTH); r.addConsumer(copying); r.addConsumer(counting); // Read the initial tag @@ -61,16 +65,15 @@ class MessageReader implements ObjectReader<UnverifiedMessage> { r.removeObjectReader(Types.AUTHOR); } // Read the subject - String subject = r.readString(ProtocolConstants.MAX_SUBJECT_LENGTH); + String subject = r.readString(MAX_SUBJECT_LENGTH); // Read the timestamp long timestamp = r.readInt64(); if(timestamp < 0L) throw new FormatException(); // Read the salt - byte[] salt = r.readBytes(ProtocolConstants.SALT_LENGTH); - if(salt.length != ProtocolConstants.SALT_LENGTH) - throw new FormatException(); + byte[] salt = r.readBytes(SALT_LENGTH); + if(salt.length != SALT_LENGTH) throw new FormatException(); // Read the message body - byte[] body = r.readBytes(ProtocolConstants.MAX_BODY_LENGTH); + byte[] body = r.readBytes(MAX_BODY_LENGTH); // Record the offset of the body within the message int bodyStart = (int) counting.getCount() - body.length; // Record the length of the data covered by the author's signature @@ -78,13 +81,13 @@ class MessageReader implements ObjectReader<UnverifiedMessage> { // Read the author's signature, if there is one byte[] authorSig = null; if(author == null) r.readNull(); - else authorSig = r.readBytes(ProtocolConstants.MAX_SIGNATURE_LENGTH); + else authorSig = r.readBytes(MAX_SIGNATURE_LENGTH); // Record the length of the data covered by the group's signature int signedByGroup = (int) counting.getCount(); // Read the group's signature, if there is one byte[] groupSig = null; if(group == null || group.getPublicKey() == null) r.readNull(); - else groupSig = r.readBytes(ProtocolConstants.MAX_SIGNATURE_LENGTH); + else groupSig = r.readBytes(MAX_SIGNATURE_LENGTH); // That's all, folks r.removeConsumer(counting); r.removeConsumer(copying); diff --git a/components/net/sf/briar/protocol/OfferReader.java b/components/net/sf/briar/protocol/OfferReader.java index d7e7dcd878b799fa453ab45a88140678ac53b80d..48e3edf5ed311e67f7b06274fc29dc3e9b5cd1ea 100644 --- a/components/net/sf/briar/protocol/OfferReader.java +++ b/components/net/sf/briar/protocol/OfferReader.java @@ -1,5 +1,7 @@ package net.sf.briar.protocol; +import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH; + import java.io.IOException; import java.util.ArrayList; import java.util.Collections; @@ -10,7 +12,6 @@ import net.sf.briar.api.FormatException; import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.PacketFactory; -import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.serial.Consumer; @@ -28,8 +29,7 @@ class OfferReader implements ObjectReader<Offer> { public Offer readObject(Reader r) throws IOException { // Initialise the consumer - Consumer counting = - new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH); + Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH); // Read the data r.addConsumer(counting); r.readStructId(Types.OFFER); diff --git a/components/net/sf/briar/protocol/ProtocolModule.java b/components/net/sf/briar/protocol/ProtocolModule.java index d3eda315b527d01f430cfbc71e1833fb64b0551a..b85e75d5351b04ee232da857dc7bddef4e1e456c 100644 --- a/components/net/sf/briar/protocol/ProtocolModule.java +++ b/components/net/sf/briar/protocol/ProtocolModule.java @@ -51,10 +51,10 @@ public class ProtocolModule extends AbstractModule { } @Provides - ObjectReader<UnverifiedBatch> getBatchReader(CryptoComponent crypto, + ObjectReader<UnverifiedBatch> getBatchReader( ObjectReader<UnverifiedMessage> messageReader, UnverifiedBatchFactory batchFactory) { - return new BatchReader(crypto, messageReader, batchFactory); + return new BatchReader(messageReader, batchFactory); } @Provides diff --git a/components/net/sf/briar/protocol/RequestReader.java b/components/net/sf/briar/protocol/RequestReader.java index cab6d20d7a9f4c18926c83cee48db7609ed40264..005bc0124e596ddb3cc44aaf12083c90ed76d358 100644 --- a/components/net/sf/briar/protocol/RequestReader.java +++ b/components/net/sf/briar/protocol/RequestReader.java @@ -1,11 +1,12 @@ package net.sf.briar.protocol; +import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH; + import java.io.IOException; import java.util.BitSet; import net.sf.briar.api.FormatException; import net.sf.briar.api.protocol.PacketFactory; -import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.Types; import net.sf.briar.api.serial.Consumer; @@ -23,14 +24,13 @@ class RequestReader implements ObjectReader<Request> { public Request readObject(Reader r) throws IOException { // Initialise the consumer - Consumer counting = - new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH); + Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH); // Read the data r.addConsumer(counting); r.readStructId(Types.REQUEST); int padding = r.readUint7(); if(padding > 7) throw new FormatException(); - byte[] bitmap = r.readBytes(ProtocolConstants.MAX_PACKET_LENGTH); + byte[] bitmap = r.readBytes(MAX_PACKET_LENGTH); r.removeConsumer(counting); // Convert the bitmap into a BitSet int length = bitmap.length * 8 - padding; diff --git a/components/net/sf/briar/protocol/SubscriptionUpdateReader.java b/components/net/sf/briar/protocol/SubscriptionUpdateReader.java index 74b55089887cbb9ed91c5ff942c88e0234d62b76..2e2dd26543290c2c753d87ae293f7ed9fd2a39de 100644 --- a/components/net/sf/briar/protocol/SubscriptionUpdateReader.java +++ b/components/net/sf/briar/protocol/SubscriptionUpdateReader.java @@ -1,12 +1,13 @@ package net.sf.briar.protocol; +import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH; + import java.io.IOException; import java.util.Map; import net.sf.briar.api.FormatException; import net.sf.briar.api.protocol.Group; import net.sf.briar.api.protocol.PacketFactory; -import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.Types; import net.sf.briar.api.serial.Consumer; @@ -27,8 +28,7 @@ class SubscriptionUpdateReader implements ObjectReader<SubscriptionUpdate> { public SubscriptionUpdate readObject(Reader r) throws IOException { // Initialise the consumer - Consumer counting = - new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH); + Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH); // Read the data r.addConsumer(counting); r.readStructId(Types.SUBSCRIPTION_UPDATE); diff --git a/components/net/sf/briar/protocol/TransportUpdateReader.java b/components/net/sf/briar/protocol/TransportUpdateReader.java index 7b52cec7da57bd9a1a6eaf026e91758d0da2d525..bbea4948f7e71e1a62838f708309457e421a4a77 100644 --- a/components/net/sf/briar/protocol/TransportUpdateReader.java +++ b/components/net/sf/briar/protocol/TransportUpdateReader.java @@ -1,5 +1,10 @@ package net.sf.briar.protocol; +import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH; +import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PROPERTIES_PER_TRANSPORT; +import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PROPERTY_LENGTH; +import static net.sf.briar.api.protocol.ProtocolConstants.MAX_TRANSPORTS; + import java.io.IOException; import java.util.Collection; import java.util.HashSet; @@ -8,7 +13,6 @@ import java.util.Set; import net.sf.briar.api.FormatException; import net.sf.briar.api.protocol.PacketFactory; -import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportIndex; @@ -32,16 +36,14 @@ class TransportUpdateReader implements ObjectReader<TransportUpdate> { public TransportUpdate readObject(Reader r) throws IOException { // Initialise the consumer - Consumer counting = - new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH); + Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH); // Read the data r.addConsumer(counting); r.readStructId(Types.TRANSPORT_UPDATE); r.addObjectReader(Types.TRANSPORT, transportReader); Collection<Transport> transports = r.readList(Transport.class); r.removeObjectReader(Types.TRANSPORT); - if(transports.size() > ProtocolConstants.MAX_TRANSPORTS) - throw new FormatException(); + if(transports.size() > MAX_TRANSPORTS) throw new FormatException(); long timestamp = r.readInt64(); r.removeConsumer(counting); // Check for duplicate IDs or indices @@ -65,14 +67,13 @@ class TransportUpdateReader implements ObjectReader<TransportUpdate> { TransportId id = new TransportId(b); // Read the index int i = r.readInt32(); - if(i < 0 || i >= ProtocolConstants.MAX_TRANSPORTS) - throw new FormatException(); + if(i < 0 || i >= MAX_TRANSPORTS) throw new FormatException(); TransportIndex index = new TransportIndex(i); // Read the properties - r.setMaxStringLength(ProtocolConstants.MAX_PROPERTY_LENGTH); + r.setMaxStringLength(MAX_PROPERTY_LENGTH); Map<String, String> m = r.readMap(String.class, String.class); r.resetMaxStringLength(); - if(m.size() > ProtocolConstants.MAX_PROPERTIES_PER_TRANSPORT) + if(m.size() > MAX_PROPERTIES_PER_TRANSPORT) throw new FormatException(); return new Transport(id, index, m); } diff --git a/components/net/sf/briar/protocol/UnverifiedBatchFactory.java b/components/net/sf/briar/protocol/UnverifiedBatchFactory.java index c420dd6b688e06b36a5e2adc223b933d90d4fae8..93e20ac8aa166d0aec4a0c1a477038c5aa6e92d3 100644 --- a/components/net/sf/briar/protocol/UnverifiedBatchFactory.java +++ b/components/net/sf/briar/protocol/UnverifiedBatchFactory.java @@ -2,11 +2,10 @@ package net.sf.briar.protocol; import java.util.Collection; -import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.UnverifiedBatch; interface UnverifiedBatchFactory { - UnverifiedBatch createUnverifiedBatch(BatchId id, + UnverifiedBatch createUnverifiedBatch( Collection<UnverifiedMessage> messages); } diff --git a/components/net/sf/briar/protocol/UnverifiedBatchFactoryImpl.java b/components/net/sf/briar/protocol/UnverifiedBatchFactoryImpl.java index b762a278d1caad7679f7d2ee60b02735b7655433..980c09db5758b2f919ee66fece3cf9839396ece8 100644 --- a/components/net/sf/briar/protocol/UnverifiedBatchFactoryImpl.java +++ b/components/net/sf/briar/protocol/UnverifiedBatchFactoryImpl.java @@ -3,7 +3,6 @@ package net.sf.briar.protocol; import java.util.Collection; import net.sf.briar.api.crypto.CryptoComponent; -import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.UnverifiedBatch; import com.google.inject.Inject; @@ -17,8 +16,8 @@ class UnverifiedBatchFactoryImpl implements UnverifiedBatchFactory { this.crypto = crypto; } - public UnverifiedBatch createUnverifiedBatch(BatchId id, + public UnverifiedBatch createUnverifiedBatch( Collection<UnverifiedMessage> messages) { - return new UnverifiedBatchImpl(crypto, id, messages); + return new UnverifiedBatchImpl(crypto, messages); } } diff --git a/components/net/sf/briar/protocol/UnverifiedBatchImpl.java b/components/net/sf/briar/protocol/UnverifiedBatchImpl.java index 8665de30a4aaa0d8e120c000b0e2d7cb42925849..a1e24354c8cfae89a24ee9cd725725cf08f34da1 100644 --- a/components/net/sf/briar/protocol/UnverifiedBatchImpl.java +++ b/components/net/sf/briar/protocol/UnverifiedBatchImpl.java @@ -24,32 +24,34 @@ import net.sf.briar.api.protocol.UnverifiedBatch; class UnverifiedBatchImpl implements UnverifiedBatch { private final CryptoComponent crypto; - private final BatchId id; private final Collection<UnverifiedMessage> messages; + private final MessageDigest batchDigest, messageDigest; - // Initialise lazily - the batch may be empty or contain unsigned messages - private MessageDigest messageDigest = null; + // Initialise lazily - the batch may contain unsigned messages private KeyParser keyParser = null; private Signature signature = null; - UnverifiedBatchImpl(CryptoComponent crypto, BatchId id, + UnverifiedBatchImpl(CryptoComponent crypto, Collection<UnverifiedMessage> messages) { this.crypto = crypto; - this.id = id; this.messages = messages; + batchDigest = crypto.getMessageDigest(); + messageDigest = crypto.getMessageDigest(); } public Batch verify() throws GeneralSecurityException { List<Message> verified = new ArrayList<Message>(); for(UnverifiedMessage m : messages) verified.add(verify(m)); + BatchId id = new BatchId(batchDigest.digest()); return new BatchImpl(id, Collections.unmodifiableList(verified)); } private Message verify(UnverifiedMessage m) throws GeneralSecurityException { - // Hash the message, including the signatures, to get the message ID + // The batch ID is the hash of the concatenated messages byte[] raw = m.getRaw(); - if(messageDigest == null) messageDigest = crypto.getMessageDigest(); + batchDigest.update(raw); + // Hash the message, including the signatures, to get the message ID messageDigest.update(raw); MessageId id = new MessageId(messageDigest.digest()); // Verify the author's signature, if there is one diff --git a/test/build.xml b/test/build.xml index d865f5bd5a1c0e90eed1c6299164417775145d1e..a37b43c2434b7553435dc81e2f61095754b5ff75 100644 --- a/test/build.xml +++ b/test/build.xml @@ -43,6 +43,7 @@ <test name='net.sf.briar.protocol.ProtocolReadWriteTest'/> <test name='net.sf.briar.protocol.ProtocolWriterImplTest'/> <test name='net.sf.briar.protocol.RequestReaderTest'/> + <test name='net.sf.briar.protocol.UnverifiedBatchImplTest'/> <test name='net.sf.briar.serial.ReaderImplTest'/> <test name='net.sf.briar.serial.WriterImplTest'/> <test name='net.sf.briar.setup.SetupWorkerTest'/> diff --git a/test/net/sf/briar/protocol/BatchReaderTest.java b/test/net/sf/briar/protocol/BatchReaderTest.java index f24722a1ff3e9e04cc84f04e699f7902f52f2f0f..4bd87ad252390bbc7dbd463d814ff957709fa8ab 100644 --- a/test/net/sf/briar/protocol/BatchReaderTest.java +++ b/test/net/sf/briar/protocol/BatchReaderTest.java @@ -7,9 +7,6 @@ import java.util.Collections; import junit.framework.TestCase; import net.sf.briar.api.FormatException; -import net.sf.briar.api.crypto.CryptoComponent; -import net.sf.briar.api.crypto.MessageDigest; -import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.UnverifiedBatch; @@ -18,7 +15,6 @@ import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.ReaderFactory; import net.sf.briar.api.serial.Writer; import net.sf.briar.api.serial.WriterFactory; -import net.sf.briar.crypto.CryptoModule; import net.sf.briar.serial.SerialModule; import org.jmock.Expectations; @@ -32,18 +28,15 @@ public class BatchReaderTest extends TestCase { private final ReaderFactory readerFactory; private final WriterFactory writerFactory; - private final CryptoComponent crypto; private final Mockery context; private final UnverifiedMessage message; private final ObjectReader<UnverifiedMessage> messageReader; public BatchReaderTest() throws Exception { super(); - Injector i = Guice.createInjector(new SerialModule(), - new CryptoModule()); + Injector i = Guice.createInjector(new SerialModule()); readerFactory = i.getInstance(ReaderFactory.class); writerFactory = i.getInstance(WriterFactory.class); - crypto = i.getInstance(CryptoComponent.class); context = new Mockery(); message = context.mock(UnverifiedMessage.class); messageReader = new TestMessageReader(); @@ -53,8 +46,7 @@ public class BatchReaderTest extends TestCase { public void testFormatExceptionIfBatchIsTooLarge() throws Exception { UnverifiedBatchFactory batchFactory = context.mock(UnverifiedBatchFactory.class); - BatchReader batchReader = new BatchReader(crypto, messageReader, - batchFactory); + BatchReader batchReader = new BatchReader(messageReader, batchFactory); byte[] b = createBatch(ProtocolConstants.MAX_PACKET_LENGTH + 1); ByteArrayInputStream in = new ByteArrayInputStream(b); @@ -72,12 +64,11 @@ public class BatchReaderTest extends TestCase { public void testNoFormatExceptionIfBatchIsMaximumSize() throws Exception { final UnverifiedBatchFactory batchFactory = context.mock(UnverifiedBatchFactory.class); - BatchReader batchReader = new BatchReader(crypto, messageReader, - batchFactory); + BatchReader batchReader = new BatchReader(messageReader, batchFactory); final UnverifiedBatch batch = context.mock(UnverifiedBatch.class); context.checking(new Expectations() {{ - oneOf(batchFactory).createUnverifiedBatch(with(any(BatchId.class)), - with(Collections.singletonList(message))); + oneOf(batchFactory).createUnverifiedBatch( + Collections.singletonList(message)); will(returnValue(batch)); }}); @@ -91,41 +82,11 @@ public class BatchReaderTest extends TestCase { context.assertIsSatisfied(); } - @Test - public void testBatchId() throws Exception { - byte[] b = createBatch(ProtocolConstants.MAX_PACKET_LENGTH); - // Calculate the expected batch ID - MessageDigest messageDigest = crypto.getMessageDigest(); - messageDigest.update(b); - final BatchId id = new BatchId(messageDigest.digest()); - - final UnverifiedBatchFactory batchFactory = - context.mock(UnverifiedBatchFactory.class); - BatchReader batchReader = new BatchReader(crypto, messageReader, - batchFactory); - final UnverifiedBatch batch = context.mock(UnverifiedBatch.class); - context.checking(new Expectations() {{ - // Check that the batch ID matches the expected ID - oneOf(batchFactory).createUnverifiedBatch(with(id), - with(Collections.singletonList(message))); - will(returnValue(batch)); - }}); - - ByteArrayInputStream in = new ByteArrayInputStream(b); - Reader reader = readerFactory.createReader(in); - reader.addObjectReader(Types.BATCH, batchReader); - - assertEquals(batch, reader.readStruct(Types.BATCH, - UnverifiedBatch.class)); - context.assertIsSatisfied(); - } - @Test public void testEmptyBatch() throws Exception { final UnverifiedBatchFactory batchFactory = context.mock(UnverifiedBatchFactory.class); - BatchReader batchReader = new BatchReader(crypto, messageReader, - batchFactory); + BatchReader batchReader = new BatchReader(messageReader, batchFactory); byte[] b = createEmptyBatch(); ByteArrayInputStream in = new ByteArrayInputStream(b); diff --git a/test/net/sf/briar/protocol/UnverifiedBatchImplTest.java b/test/net/sf/briar/protocol/UnverifiedBatchImplTest.java new file mode 100644 index 0000000000000000000000000000000000000000..c0924c22ab2a2dd4f8d48cb25c734d626e7f567c --- /dev/null +++ b/test/net/sf/briar/protocol/UnverifiedBatchImplTest.java @@ -0,0 +1,242 @@ +package net.sf.briar.protocol; + +import java.security.GeneralSecurityException; +import java.security.KeyPair; +import java.security.Signature; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.Random; + +import junit.framework.TestCase; +import net.sf.briar.TestUtils; +import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.api.crypto.MessageDigest; +import net.sf.briar.api.protocol.Author; +import net.sf.briar.api.protocol.AuthorId; +import net.sf.briar.api.protocol.Batch; +import net.sf.briar.api.protocol.BatchId; +import net.sf.briar.api.protocol.Group; +import net.sf.briar.api.protocol.GroupId; +import net.sf.briar.api.protocol.Message; +import net.sf.briar.api.protocol.MessageId; +import net.sf.briar.api.protocol.UnverifiedBatch; +import net.sf.briar.crypto.CryptoModule; + +import org.jmock.Expectations; +import org.jmock.Mockery; +import org.junit.Test; + +import com.google.inject.Guice; +import com.google.inject.Injector; + +public class UnverifiedBatchImplTest extends TestCase { + + private final CryptoComponent crypto; + private final byte[] raw, raw1; + private final String subject; + private final long timestamp; + + public UnverifiedBatchImplTest() { + super(); + Injector i = Guice.createInjector(new CryptoModule()); + crypto = i.getInstance(CryptoComponent.class); + Random r = new Random(); + raw = new byte[123]; + r.nextBytes(raw); + raw1 = new byte[1234]; + r.nextBytes(raw1); + subject = "Unit tests are exciting"; + timestamp = System.currentTimeMillis(); + } + + @Test + public void testIds() throws Exception { + // Calculate the expected batch and message IDs + MessageDigest messageDigest = crypto.getMessageDigest(); + messageDigest.update(raw); + messageDigest.update(raw1); + BatchId batchId = new BatchId(messageDigest.digest()); + messageDigest.update(raw); + MessageId messageId = new MessageId(messageDigest.digest()); + messageDigest.update(raw1); + MessageId messageId1 = new MessageId(messageDigest.digest()); + // Verify the batch + Mockery context = new Mockery(); + final UnverifiedMessage message = + context.mock(UnverifiedMessage.class, "message"); + final UnverifiedMessage message1 = + context.mock(UnverifiedMessage.class, "message1"); + context.checking(new Expectations() {{ + // First message + oneOf(message).getRaw(); + will(returnValue(raw)); + oneOf(message).getAuthor(); + will(returnValue(null)); + oneOf(message).getGroup(); + will(returnValue(null)); + oneOf(message).getParent(); + will(returnValue(null)); + oneOf(message).getSubject(); + will(returnValue(subject)); + oneOf(message).getTimestamp(); + will(returnValue(timestamp)); + oneOf(message).getBodyStart(); + will(returnValue(10)); + oneOf(message).getBodyLength(); + will(returnValue(100)); + // Second message + oneOf(message1).getRaw(); + will(returnValue(raw1)); + oneOf(message1).getAuthor(); + will(returnValue(null)); + oneOf(message1).getGroup(); + will(returnValue(null)); + oneOf(message1).getParent(); + will(returnValue(null)); + oneOf(message1).getSubject(); + will(returnValue(subject)); + oneOf(message1).getTimestamp(); + will(returnValue(timestamp)); + oneOf(message1).getBodyStart(); + will(returnValue(10)); + oneOf(message1).getBodyLength(); + will(returnValue(1000)); + }}); + Collection<UnverifiedMessage> messages = + Arrays.asList(new UnverifiedMessage[] {message, message1}); + UnverifiedBatch batch = new UnverifiedBatchImpl(crypto, messages); + Batch verifiedBatch = batch.verify(); + // Check that the batch and message IDs match + assertEquals(batchId, verifiedBatch.getId()); + Collection<Message> verifiedMessages = verifiedBatch.getMessages(); + assertEquals(2, verifiedMessages.size()); + Iterator<Message> it = verifiedMessages.iterator(); + Message verifiedMessage = it.next(); + assertEquals(messageId, verifiedMessage.getId()); + Message verifiedMessage1 = it.next(); + assertEquals(messageId1, verifiedMessage1.getId()); + context.assertIsSatisfied(); + } + + @Test + public void testSignatures() throws Exception { + final int signedByAuthor = 100, signedByGroup = 110; + final KeyPair authorKeyPair = crypto.generateKeyPair(); + final KeyPair groupKeyPair = crypto.generateKeyPair(); + Signature signature = crypto.getSignature(); + // Calculate the expected author and group signatures + signature.initSign(authorKeyPair.getPrivate()); + signature.update(raw, 0, signedByAuthor); + final byte[] authorSignature = signature.sign(); + signature.initSign(groupKeyPair.getPrivate()); + signature.update(raw, 0, signedByGroup); + final byte[] groupSignature = signature.sign(); + // Verify the batch + Mockery context = new Mockery(); + final UnverifiedMessage message = + context.mock(UnverifiedMessage.class, "message"); + final Author author = context.mock(Author.class); + final Group group = context.mock(Group.class); + final UnverifiedMessage message1 = + context.mock(UnverifiedMessage.class, "message1"); + context.checking(new Expectations() {{ + // First message + oneOf(message).getRaw(); + will(returnValue(raw)); + oneOf(message).getAuthor(); + will(returnValue(author)); + oneOf(author).getPublicKey(); + will(returnValue(authorKeyPair.getPublic().getEncoded())); + oneOf(message).getLengthSignedByAuthor(); + will(returnValue(signedByAuthor)); + oneOf(message).getAuthorSignature(); + will(returnValue(authorSignature)); + oneOf(message).getGroup(); + will(returnValue(group)); + exactly(2).of(group).getPublicKey(); + will(returnValue(groupKeyPair.getPublic().getEncoded())); + oneOf(message).getLengthSignedByGroup(); + will(returnValue(signedByGroup)); + oneOf(message).getGroupSignature(); + will(returnValue(groupSignature)); + oneOf(author).getId(); + will(returnValue(new AuthorId(TestUtils.getRandomId()))); + oneOf(group).getId(); + will(returnValue(new GroupId(TestUtils.getRandomId()))); + oneOf(message).getParent(); + will(returnValue(null)); + oneOf(message).getSubject(); + will(returnValue(subject)); + oneOf(message).getTimestamp(); + will(returnValue(timestamp)); + oneOf(message).getBodyStart(); + will(returnValue(10)); + oneOf(message).getBodyLength(); + will(returnValue(100)); + // Second message + oneOf(message1).getRaw(); + will(returnValue(raw1)); + oneOf(message1).getAuthor(); + will(returnValue(null)); + oneOf(message1).getGroup(); + will(returnValue(null)); + oneOf(message1).getParent(); + will(returnValue(null)); + oneOf(message1).getSubject(); + will(returnValue(subject)); + oneOf(message1).getTimestamp(); + will(returnValue(timestamp)); + oneOf(message1).getBodyStart(); + will(returnValue(10)); + oneOf(message1).getBodyLength(); + will(returnValue(1000)); + }}); + Collection<UnverifiedMessage> messages = + Arrays.asList(new UnverifiedMessage[] {message, message1}); + UnverifiedBatch batch = new UnverifiedBatchImpl(crypto, messages); + batch.verify(); + context.assertIsSatisfied(); + } + + @Test + public void testExceptionThrownIfMessageIsModified() throws Exception { + final int signedByAuthor = 100; + final KeyPair authorKeyPair = crypto.generateKeyPair(); + Signature signature = crypto.getSignature(); + // Calculate the expected author signature + signature.initSign(authorKeyPair.getPrivate()); + signature.update(raw, 0, signedByAuthor); + final byte[] authorSignature = signature.sign(); + // Modify the message + raw[signedByAuthor / 2] ^= 0xff; + // Verify the batch + Mockery context = new Mockery(); + final UnverifiedMessage message = + context.mock(UnverifiedMessage.class, "message"); + final Author author = context.mock(Author.class); + final UnverifiedMessage message1 = + context.mock(UnverifiedMessage.class, "message1"); + context.checking(new Expectations() {{ + // First message - verification will fail at the author's signature + oneOf(message).getRaw(); + will(returnValue(raw)); + oneOf(message).getAuthor(); + will(returnValue(author)); + oneOf(author).getPublicKey(); + will(returnValue(authorKeyPair.getPublic().getEncoded())); + oneOf(message).getLengthSignedByAuthor(); + will(returnValue(signedByAuthor)); + oneOf(message).getAuthorSignature(); + will(returnValue(authorSignature)); + }}); + Collection<UnverifiedMessage> messages = + Arrays.asList(new UnverifiedMessage[] {message, message1}); + UnverifiedBatch batch = new UnverifiedBatchImpl(crypto, messages); + try { + batch.verify(); + fail(); + } catch(GeneralSecurityException expected) {} + context.assertIsSatisfied(); + } +}