From ae87100c8f48a1f120ae405fc7ef4cad7dc6834b Mon Sep 17 00:00:00 2001
From: akwizgran <akwizgran@users.sourceforge.net>
Date: Thu, 8 Dec 2011 12:51:34 +0000
Subject: [PATCH] Moved batch ID calculation off the IO thread.

---
 .../net/sf/briar/protocol/AckReader.java      |   6 +-
 .../net/sf/briar/protocol/BatchReader.java    |  25 +-
 .../net/sf/briar/protocol/MessageReader.java  |  23 +-
 .../net/sf/briar/protocol/OfferReader.java    |   6 +-
 .../net/sf/briar/protocol/ProtocolModule.java |   4 +-
 .../net/sf/briar/protocol/RequestReader.java  |   8 +-
 .../protocol/SubscriptionUpdateReader.java    |   6 +-
 .../briar/protocol/TransportUpdateReader.java |  19 +-
 .../protocol/UnverifiedBatchFactory.java      |   3 +-
 .../protocol/UnverifiedBatchFactoryImpl.java  |   5 +-
 .../briar/protocol/UnverifiedBatchImpl.java   |  16 +-
 test/build.xml                                |   1 +
 .../sf/briar/protocol/BatchReaderTest.java    |  51 +---
 .../protocol/UnverifiedBatchImplTest.java     | 242 ++++++++++++++++++
 14 files changed, 306 insertions(+), 109 deletions(-)
 create mode 100644 test/net/sf/briar/protocol/UnverifiedBatchImplTest.java

diff --git a/components/net/sf/briar/protocol/AckReader.java b/components/net/sf/briar/protocol/AckReader.java
index 977eaa8dab..0e89c6bcb1 100644
--- a/components/net/sf/briar/protocol/AckReader.java
+++ b/components/net/sf/briar/protocol/AckReader.java
@@ -1,5 +1,7 @@
 package net.sf.briar.protocol;
 
+import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -10,7 +12,6 @@ import net.sf.briar.api.FormatException;
 import net.sf.briar.api.protocol.Ack;
 import net.sf.briar.api.protocol.BatchId;
 import net.sf.briar.api.protocol.PacketFactory;
-import net.sf.briar.api.protocol.ProtocolConstants;
 import net.sf.briar.api.protocol.Types;
 import net.sf.briar.api.protocol.UniqueId;
 import net.sf.briar.api.serial.Consumer;
@@ -28,8 +29,7 @@ class AckReader implements ObjectReader<Ack> {
 
 	public Ack readObject(Reader r) throws IOException {
 		// Initialise the consumer
-		Consumer counting =
-			new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
+		Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
 		// Read the data
 		r.addConsumer(counting);
 		r.readStructId(Types.ACK);
diff --git a/components/net/sf/briar/protocol/BatchReader.java b/components/net/sf/briar/protocol/BatchReader.java
index 4864d6121a..55031796a1 100644
--- a/components/net/sf/briar/protocol/BatchReader.java
+++ b/components/net/sf/briar/protocol/BatchReader.java
@@ -1,52 +1,41 @@
 package net.sf.briar.protocol;
 
+import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
+
 import java.io.IOException;
 import java.util.List;
 
 import net.sf.briar.api.FormatException;
-import net.sf.briar.api.crypto.CryptoComponent;
-import net.sf.briar.api.crypto.MessageDigest;
-import net.sf.briar.api.protocol.BatchId;
-import net.sf.briar.api.protocol.ProtocolConstants;
 import net.sf.briar.api.protocol.Types;
 import net.sf.briar.api.protocol.UnverifiedBatch;
 import net.sf.briar.api.serial.Consumer;
 import net.sf.briar.api.serial.CountingConsumer;
-import net.sf.briar.api.serial.DigestingConsumer;
 import net.sf.briar.api.serial.ObjectReader;
 import net.sf.briar.api.serial.Reader;
 
 class BatchReader implements ObjectReader<UnverifiedBatch> {
 
-	private final MessageDigest messageDigest;
 	private final ObjectReader<UnverifiedMessage> messageReader;
 	private final UnverifiedBatchFactory batchFactory;
 
-	BatchReader(CryptoComponent crypto,
-			ObjectReader<UnverifiedMessage> messageReader,
+	BatchReader(ObjectReader<UnverifiedMessage> messageReader,
 			UnverifiedBatchFactory batchFactory) {
-		messageDigest = crypto.getMessageDigest();
 		this.messageReader = messageReader;
 		this.batchFactory = batchFactory;
 	}
 
 	public UnverifiedBatch readObject(Reader r) throws IOException {
-		// Initialise the consumers
-		Consumer counting =
-			new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
-		DigestingConsumer digesting = new DigestingConsumer(messageDigest);
-		// Read and digest the data
+		// Initialise the consumer
+		Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
+		// Read the data
 		r.addConsumer(counting);
-		r.addConsumer(digesting);
 		r.readStructId(Types.BATCH);
 		r.addObjectReader(Types.MESSAGE, messageReader);
 		List<UnverifiedMessage> messages = r.readList(UnverifiedMessage.class);
 		r.removeObjectReader(Types.MESSAGE);
-		r.removeConsumer(digesting);
 		r.removeConsumer(counting);
 		if(messages.isEmpty()) throw new FormatException();
 		// Build and return the batch
-		BatchId id = new BatchId(messageDigest.digest());
-		return batchFactory.createUnverifiedBatch(id, messages);
+		return batchFactory.createUnverifiedBatch( messages);
 	}
 }
diff --git a/components/net/sf/briar/protocol/MessageReader.java b/components/net/sf/briar/protocol/MessageReader.java
index 4c7b8a1a6d..7a4ddc22f9 100644
--- a/components/net/sf/briar/protocol/MessageReader.java
+++ b/components/net/sf/briar/protocol/MessageReader.java
@@ -1,12 +1,17 @@
 package net.sf.briar.protocol;
 
+import static net.sf.briar.api.protocol.ProtocolConstants.MAX_BODY_LENGTH;
+import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
+import static net.sf.briar.api.protocol.ProtocolConstants.MAX_SIGNATURE_LENGTH;
+import static net.sf.briar.api.protocol.ProtocolConstants.MAX_SUBJECT_LENGTH;
+import static net.sf.briar.api.protocol.ProtocolConstants.SALT_LENGTH;
+
 import java.io.IOException;
 
 import net.sf.briar.api.FormatException;
 import net.sf.briar.api.protocol.Author;
 import net.sf.briar.api.protocol.Group;
 import net.sf.briar.api.protocol.MessageId;
-import net.sf.briar.api.protocol.ProtocolConstants;
 import net.sf.briar.api.protocol.Types;
 import net.sf.briar.api.protocol.UniqueId;
 import net.sf.briar.api.serial.CopyingConsumer;
@@ -27,8 +32,7 @@ class MessageReader implements ObjectReader<UnverifiedMessage> {
 
 	public UnverifiedMessage readObject(Reader r) throws IOException {
 		CopyingConsumer copying = new CopyingConsumer();
-		CountingConsumer counting =
-			new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
+		CountingConsumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
 		r.addConsumer(copying);
 		r.addConsumer(counting);
 		// Read the initial tag
@@ -61,16 +65,15 @@ class MessageReader implements ObjectReader<UnverifiedMessage> {
 			r.removeObjectReader(Types.AUTHOR);
 		}
 		// Read the subject
-		String subject = r.readString(ProtocolConstants.MAX_SUBJECT_LENGTH);
+		String subject = r.readString(MAX_SUBJECT_LENGTH);
 		// Read the timestamp
 		long timestamp = r.readInt64();
 		if(timestamp < 0L) throw new FormatException();
 		// Read the salt
-		byte[] salt = r.readBytes(ProtocolConstants.SALT_LENGTH);
-		if(salt.length != ProtocolConstants.SALT_LENGTH)
-			throw new FormatException();
+		byte[] salt = r.readBytes(SALT_LENGTH);
+		if(salt.length != SALT_LENGTH) throw new FormatException();
 		// Read the message body
-		byte[] body = r.readBytes(ProtocolConstants.MAX_BODY_LENGTH);
+		byte[] body = r.readBytes(MAX_BODY_LENGTH);
 		// Record the offset of the body within the message
 		int bodyStart = (int) counting.getCount() - body.length;
 		// Record the length of the data covered by the author's signature
@@ -78,13 +81,13 @@ class MessageReader implements ObjectReader<UnverifiedMessage> {
 		// Read the author's signature, if there is one
 		byte[] authorSig = null;
 		if(author == null) r.readNull();
-		else authorSig = r.readBytes(ProtocolConstants.MAX_SIGNATURE_LENGTH);
+		else authorSig = r.readBytes(MAX_SIGNATURE_LENGTH);
 		// Record the length of the data covered by the group's signature
 		int signedByGroup = (int) counting.getCount();
 		// Read the group's signature, if there is one
 		byte[] groupSig = null;
 		if(group == null || group.getPublicKey() == null) r.readNull();
-		else groupSig = r.readBytes(ProtocolConstants.MAX_SIGNATURE_LENGTH);
+		else groupSig = r.readBytes(MAX_SIGNATURE_LENGTH);
 		// That's all, folks
 		r.removeConsumer(counting);
 		r.removeConsumer(copying);
diff --git a/components/net/sf/briar/protocol/OfferReader.java b/components/net/sf/briar/protocol/OfferReader.java
index d7e7dcd878..48e3edf5ed 100644
--- a/components/net/sf/briar/protocol/OfferReader.java
+++ b/components/net/sf/briar/protocol/OfferReader.java
@@ -1,5 +1,7 @@
 package net.sf.briar.protocol;
 
+import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
+
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -10,7 +12,6 @@ import net.sf.briar.api.FormatException;
 import net.sf.briar.api.protocol.MessageId;
 import net.sf.briar.api.protocol.Offer;
 import net.sf.briar.api.protocol.PacketFactory;
-import net.sf.briar.api.protocol.ProtocolConstants;
 import net.sf.briar.api.protocol.Types;
 import net.sf.briar.api.protocol.UniqueId;
 import net.sf.briar.api.serial.Consumer;
@@ -28,8 +29,7 @@ class OfferReader implements ObjectReader<Offer> {
 
 	public Offer readObject(Reader r) throws IOException {
 		// Initialise the consumer
-		Consumer counting =
-			new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
+		Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
 		// Read the data
 		r.addConsumer(counting);
 		r.readStructId(Types.OFFER);
diff --git a/components/net/sf/briar/protocol/ProtocolModule.java b/components/net/sf/briar/protocol/ProtocolModule.java
index d3eda315b5..b85e75d535 100644
--- a/components/net/sf/briar/protocol/ProtocolModule.java
+++ b/components/net/sf/briar/protocol/ProtocolModule.java
@@ -51,10 +51,10 @@ public class ProtocolModule extends AbstractModule {
 	}
 
 	@Provides
-	ObjectReader<UnverifiedBatch> getBatchReader(CryptoComponent crypto,
+	ObjectReader<UnverifiedBatch> getBatchReader(
 			ObjectReader<UnverifiedMessage> messageReader,
 			UnverifiedBatchFactory batchFactory) {
-		return new BatchReader(crypto, messageReader, batchFactory);
+		return new BatchReader(messageReader, batchFactory);
 	}
 
 	@Provides
diff --git a/components/net/sf/briar/protocol/RequestReader.java b/components/net/sf/briar/protocol/RequestReader.java
index cab6d20d7a..005bc0124e 100644
--- a/components/net/sf/briar/protocol/RequestReader.java
+++ b/components/net/sf/briar/protocol/RequestReader.java
@@ -1,11 +1,12 @@
 package net.sf.briar.protocol;
 
+import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
+
 import java.io.IOException;
 import java.util.BitSet;
 
 import net.sf.briar.api.FormatException;
 import net.sf.briar.api.protocol.PacketFactory;
-import net.sf.briar.api.protocol.ProtocolConstants;
 import net.sf.briar.api.protocol.Request;
 import net.sf.briar.api.protocol.Types;
 import net.sf.briar.api.serial.Consumer;
@@ -23,14 +24,13 @@ class RequestReader implements ObjectReader<Request> {
 
 	public Request readObject(Reader r) throws IOException {
 		// Initialise the consumer
-		Consumer counting =
-			new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
+		Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
 		// Read the data
 		r.addConsumer(counting);
 		r.readStructId(Types.REQUEST);
 		int padding = r.readUint7();
 		if(padding > 7) throw new FormatException();
-		byte[] bitmap = r.readBytes(ProtocolConstants.MAX_PACKET_LENGTH);
+		byte[] bitmap = r.readBytes(MAX_PACKET_LENGTH);
 		r.removeConsumer(counting);
 		// Convert the bitmap into a BitSet
 		int length = bitmap.length * 8 - padding;
diff --git a/components/net/sf/briar/protocol/SubscriptionUpdateReader.java b/components/net/sf/briar/protocol/SubscriptionUpdateReader.java
index 74b5508988..2e2dd26543 100644
--- a/components/net/sf/briar/protocol/SubscriptionUpdateReader.java
+++ b/components/net/sf/briar/protocol/SubscriptionUpdateReader.java
@@ -1,12 +1,13 @@
 package net.sf.briar.protocol;
 
+import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
+
 import java.io.IOException;
 import java.util.Map;
 
 import net.sf.briar.api.FormatException;
 import net.sf.briar.api.protocol.Group;
 import net.sf.briar.api.protocol.PacketFactory;
-import net.sf.briar.api.protocol.ProtocolConstants;
 import net.sf.briar.api.protocol.SubscriptionUpdate;
 import net.sf.briar.api.protocol.Types;
 import net.sf.briar.api.serial.Consumer;
@@ -27,8 +28,7 @@ class SubscriptionUpdateReader implements ObjectReader<SubscriptionUpdate> {
 
 	public SubscriptionUpdate readObject(Reader r) throws IOException {
 		// Initialise the consumer
-		Consumer counting =
-			new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
+		Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
 		// Read the data
 		r.addConsumer(counting);
 		r.readStructId(Types.SUBSCRIPTION_UPDATE);
diff --git a/components/net/sf/briar/protocol/TransportUpdateReader.java b/components/net/sf/briar/protocol/TransportUpdateReader.java
index 7b52cec7da..bbea4948f7 100644
--- a/components/net/sf/briar/protocol/TransportUpdateReader.java
+++ b/components/net/sf/briar/protocol/TransportUpdateReader.java
@@ -1,5 +1,10 @@
 package net.sf.briar.protocol;
 
+import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH;
+import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PROPERTIES_PER_TRANSPORT;
+import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PROPERTY_LENGTH;
+import static net.sf.briar.api.protocol.ProtocolConstants.MAX_TRANSPORTS;
+
 import java.io.IOException;
 import java.util.Collection;
 import java.util.HashSet;
@@ -8,7 +13,6 @@ import java.util.Set;
 
 import net.sf.briar.api.FormatException;
 import net.sf.briar.api.protocol.PacketFactory;
-import net.sf.briar.api.protocol.ProtocolConstants;
 import net.sf.briar.api.protocol.Transport;
 import net.sf.briar.api.protocol.TransportId;
 import net.sf.briar.api.protocol.TransportIndex;
@@ -32,16 +36,14 @@ class TransportUpdateReader implements ObjectReader<TransportUpdate> {
 
 	public TransportUpdate readObject(Reader r) throws IOException {
 		// Initialise the consumer
-		Consumer counting =
-			new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH);
+		Consumer counting = new CountingConsumer(MAX_PACKET_LENGTH);
 		// Read the data
 		r.addConsumer(counting);
 		r.readStructId(Types.TRANSPORT_UPDATE);
 		r.addObjectReader(Types.TRANSPORT, transportReader);
 		Collection<Transport> transports = r.readList(Transport.class);
 		r.removeObjectReader(Types.TRANSPORT);
-		if(transports.size() > ProtocolConstants.MAX_TRANSPORTS)
-			throw new FormatException();
+		if(transports.size() > MAX_TRANSPORTS) throw new FormatException();
 		long timestamp = r.readInt64();
 		r.removeConsumer(counting);
 		// Check for duplicate IDs or indices
@@ -65,14 +67,13 @@ class TransportUpdateReader implements ObjectReader<TransportUpdate> {
 			TransportId id = new TransportId(b);
 			// Read the index
 			int i = r.readInt32();
-			if(i < 0 || i >= ProtocolConstants.MAX_TRANSPORTS)
-				throw new FormatException();
+			if(i < 0 || i >= MAX_TRANSPORTS) throw new FormatException();
 			TransportIndex index = new TransportIndex(i);
 			// Read the properties
-			r.setMaxStringLength(ProtocolConstants.MAX_PROPERTY_LENGTH);
+			r.setMaxStringLength(MAX_PROPERTY_LENGTH);
 			Map<String, String> m = r.readMap(String.class, String.class);
 			r.resetMaxStringLength();
-			if(m.size() > ProtocolConstants.MAX_PROPERTIES_PER_TRANSPORT)
+			if(m.size() > MAX_PROPERTIES_PER_TRANSPORT)
 				throw new FormatException();
 			return new Transport(id, index, m);
 		}
diff --git a/components/net/sf/briar/protocol/UnverifiedBatchFactory.java b/components/net/sf/briar/protocol/UnverifiedBatchFactory.java
index c420dd6b68..93e20ac8aa 100644
--- a/components/net/sf/briar/protocol/UnverifiedBatchFactory.java
+++ b/components/net/sf/briar/protocol/UnverifiedBatchFactory.java
@@ -2,11 +2,10 @@ package net.sf.briar.protocol;
 
 import java.util.Collection;
 
-import net.sf.briar.api.protocol.BatchId;
 import net.sf.briar.api.protocol.UnverifiedBatch;
 
 interface UnverifiedBatchFactory {
 
-	UnverifiedBatch createUnverifiedBatch(BatchId id,
+	UnverifiedBatch createUnverifiedBatch(
 			Collection<UnverifiedMessage> messages);
 }
diff --git a/components/net/sf/briar/protocol/UnverifiedBatchFactoryImpl.java b/components/net/sf/briar/protocol/UnverifiedBatchFactoryImpl.java
index b762a278d1..980c09db57 100644
--- a/components/net/sf/briar/protocol/UnverifiedBatchFactoryImpl.java
+++ b/components/net/sf/briar/protocol/UnverifiedBatchFactoryImpl.java
@@ -3,7 +3,6 @@ package net.sf.briar.protocol;
 import java.util.Collection;
 
 import net.sf.briar.api.crypto.CryptoComponent;
-import net.sf.briar.api.protocol.BatchId;
 import net.sf.briar.api.protocol.UnverifiedBatch;
 
 import com.google.inject.Inject;
@@ -17,8 +16,8 @@ class UnverifiedBatchFactoryImpl implements UnverifiedBatchFactory {
 		this.crypto = crypto;
 	}
 
-	public UnverifiedBatch createUnverifiedBatch(BatchId id,
+	public UnverifiedBatch createUnverifiedBatch(
 			Collection<UnverifiedMessage> messages) {
-		return new UnverifiedBatchImpl(crypto, id, messages);
+		return new UnverifiedBatchImpl(crypto, messages);
 	}
 }
diff --git a/components/net/sf/briar/protocol/UnverifiedBatchImpl.java b/components/net/sf/briar/protocol/UnverifiedBatchImpl.java
index 8665de30a4..a1e24354c8 100644
--- a/components/net/sf/briar/protocol/UnverifiedBatchImpl.java
+++ b/components/net/sf/briar/protocol/UnverifiedBatchImpl.java
@@ -24,32 +24,34 @@ import net.sf.briar.api.protocol.UnverifiedBatch;
 class UnverifiedBatchImpl implements UnverifiedBatch {
 
 	private final CryptoComponent crypto;
-	private final BatchId id;
 	private final Collection<UnverifiedMessage> messages;
+	private final MessageDigest batchDigest, messageDigest;
 
-	// Initialise lazily - the batch may be empty or contain unsigned messages
-	private MessageDigest messageDigest = null;
+	// Initialise lazily - the batch may contain unsigned messages
 	private KeyParser keyParser = null;
 	private Signature signature = null;
 
-	UnverifiedBatchImpl(CryptoComponent crypto, BatchId id,
+	UnverifiedBatchImpl(CryptoComponent crypto,
 			Collection<UnverifiedMessage> messages) {
 		this.crypto = crypto;
-		this.id = id;
 		this.messages = messages;
+		batchDigest = crypto.getMessageDigest();
+		messageDigest = crypto.getMessageDigest();
 	}
 
 	public Batch verify() throws GeneralSecurityException {
 		List<Message> verified = new ArrayList<Message>();
 		for(UnverifiedMessage m : messages) verified.add(verify(m));
+		BatchId id = new BatchId(batchDigest.digest());
 		return new BatchImpl(id, Collections.unmodifiableList(verified));
 	}
 
 	private Message verify(UnverifiedMessage m)
 	throws GeneralSecurityException {
-		// Hash the message, including the signatures, to get the message ID
+		// The batch ID is the hash of the concatenated messages
 		byte[] raw = m.getRaw();
-		if(messageDigest == null) messageDigest = crypto.getMessageDigest();
+		batchDigest.update(raw);
+		// Hash the message, including the signatures, to get the message ID
 		messageDigest.update(raw);
 		MessageId id = new MessageId(messageDigest.digest());
 		// Verify the author's signature, if there is one
diff --git a/test/build.xml b/test/build.xml
index d865f5bd5a..a37b43c243 100644
--- a/test/build.xml
+++ b/test/build.xml
@@ -43,6 +43,7 @@
 			<test name='net.sf.briar.protocol.ProtocolReadWriteTest'/>
 			<test name='net.sf.briar.protocol.ProtocolWriterImplTest'/>
 			<test name='net.sf.briar.protocol.RequestReaderTest'/>
+			<test name='net.sf.briar.protocol.UnverifiedBatchImplTest'/>
 			<test name='net.sf.briar.serial.ReaderImplTest'/>
 			<test name='net.sf.briar.serial.WriterImplTest'/>
 			<test name='net.sf.briar.setup.SetupWorkerTest'/>
diff --git a/test/net/sf/briar/protocol/BatchReaderTest.java b/test/net/sf/briar/protocol/BatchReaderTest.java
index f24722a1ff..4bd87ad252 100644
--- a/test/net/sf/briar/protocol/BatchReaderTest.java
+++ b/test/net/sf/briar/protocol/BatchReaderTest.java
@@ -7,9 +7,6 @@ import java.util.Collections;
 
 import junit.framework.TestCase;
 import net.sf.briar.api.FormatException;
-import net.sf.briar.api.crypto.CryptoComponent;
-import net.sf.briar.api.crypto.MessageDigest;
-import net.sf.briar.api.protocol.BatchId;
 import net.sf.briar.api.protocol.ProtocolConstants;
 import net.sf.briar.api.protocol.Types;
 import net.sf.briar.api.protocol.UnverifiedBatch;
@@ -18,7 +15,6 @@ import net.sf.briar.api.serial.Reader;
 import net.sf.briar.api.serial.ReaderFactory;
 import net.sf.briar.api.serial.Writer;
 import net.sf.briar.api.serial.WriterFactory;
-import net.sf.briar.crypto.CryptoModule;
 import net.sf.briar.serial.SerialModule;
 
 import org.jmock.Expectations;
@@ -32,18 +28,15 @@ public class BatchReaderTest extends TestCase {
 
 	private final ReaderFactory readerFactory;
 	private final WriterFactory writerFactory;
-	private final CryptoComponent crypto;
 	private final Mockery context;
 	private final UnverifiedMessage message;
 	private final ObjectReader<UnverifiedMessage> messageReader;
 
 	public BatchReaderTest() throws Exception {
 		super();
-		Injector i = Guice.createInjector(new SerialModule(),
-				new CryptoModule());
+		Injector i = Guice.createInjector(new SerialModule());
 		readerFactory = i.getInstance(ReaderFactory.class);
 		writerFactory = i.getInstance(WriterFactory.class);
-		crypto = i.getInstance(CryptoComponent.class);
 		context = new Mockery();
 		message = context.mock(UnverifiedMessage.class);
 		messageReader = new TestMessageReader();
@@ -53,8 +46,7 @@ public class BatchReaderTest extends TestCase {
 	public void testFormatExceptionIfBatchIsTooLarge() throws Exception {
 		UnverifiedBatchFactory batchFactory =
 			context.mock(UnverifiedBatchFactory.class);
-		BatchReader batchReader = new BatchReader(crypto, messageReader,
-				batchFactory);
+		BatchReader batchReader = new BatchReader(messageReader, batchFactory);
 
 		byte[] b = createBatch(ProtocolConstants.MAX_PACKET_LENGTH + 1);
 		ByteArrayInputStream in = new ByteArrayInputStream(b);
@@ -72,12 +64,11 @@ public class BatchReaderTest extends TestCase {
 	public void testNoFormatExceptionIfBatchIsMaximumSize() throws Exception {
 		final UnverifiedBatchFactory batchFactory =
 			context.mock(UnverifiedBatchFactory.class);
-		BatchReader batchReader = new BatchReader(crypto, messageReader,
-				batchFactory);
+		BatchReader batchReader = new BatchReader(messageReader, batchFactory);
 		final UnverifiedBatch batch = context.mock(UnverifiedBatch.class);
 		context.checking(new Expectations() {{
-			oneOf(batchFactory).createUnverifiedBatch(with(any(BatchId.class)),
-					with(Collections.singletonList(message)));
+			oneOf(batchFactory).createUnverifiedBatch(
+					Collections.singletonList(message));
 			will(returnValue(batch));
 		}});
 
@@ -91,41 +82,11 @@ public class BatchReaderTest extends TestCase {
 		context.assertIsSatisfied();
 	}
 
-	@Test
-	public void testBatchId() throws Exception {
-		byte[] b = createBatch(ProtocolConstants.MAX_PACKET_LENGTH);
-		// Calculate the expected batch ID
-		MessageDigest messageDigest = crypto.getMessageDigest();
-		messageDigest.update(b);
-		final BatchId id = new BatchId(messageDigest.digest());
-
-		final UnverifiedBatchFactory batchFactory =
-			context.mock(UnverifiedBatchFactory.class);
-		BatchReader batchReader = new BatchReader(crypto, messageReader,
-				batchFactory);
-		final UnverifiedBatch batch = context.mock(UnverifiedBatch.class);
-		context.checking(new Expectations() {{
-			// Check that the batch ID matches the expected ID
-			oneOf(batchFactory).createUnverifiedBatch(with(id),
-					with(Collections.singletonList(message)));
-			will(returnValue(batch));
-		}});
-
-		ByteArrayInputStream in = new ByteArrayInputStream(b);
-		Reader reader = readerFactory.createReader(in);
-		reader.addObjectReader(Types.BATCH, batchReader);
-
-		assertEquals(batch, reader.readStruct(Types.BATCH,
-				UnverifiedBatch.class));
-		context.assertIsSatisfied();
-	}
-
 	@Test
 	public void testEmptyBatch() throws Exception {
 		final UnverifiedBatchFactory batchFactory =
 			context.mock(UnverifiedBatchFactory.class);
-		BatchReader batchReader = new BatchReader(crypto, messageReader,
-				batchFactory);
+		BatchReader batchReader = new BatchReader(messageReader, batchFactory);
 
 		byte[] b = createEmptyBatch();
 		ByteArrayInputStream in = new ByteArrayInputStream(b);
diff --git a/test/net/sf/briar/protocol/UnverifiedBatchImplTest.java b/test/net/sf/briar/protocol/UnverifiedBatchImplTest.java
new file mode 100644
index 0000000000..c0924c22ab
--- /dev/null
+++ b/test/net/sf/briar/protocol/UnverifiedBatchImplTest.java
@@ -0,0 +1,242 @@
+package net.sf.briar.protocol;
+
+import java.security.GeneralSecurityException;
+import java.security.KeyPair;
+import java.security.Signature;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Random;
+
+import junit.framework.TestCase;
+import net.sf.briar.TestUtils;
+import net.sf.briar.api.crypto.CryptoComponent;
+import net.sf.briar.api.crypto.MessageDigest;
+import net.sf.briar.api.protocol.Author;
+import net.sf.briar.api.protocol.AuthorId;
+import net.sf.briar.api.protocol.Batch;
+import net.sf.briar.api.protocol.BatchId;
+import net.sf.briar.api.protocol.Group;
+import net.sf.briar.api.protocol.GroupId;
+import net.sf.briar.api.protocol.Message;
+import net.sf.briar.api.protocol.MessageId;
+import net.sf.briar.api.protocol.UnverifiedBatch;
+import net.sf.briar.crypto.CryptoModule;
+
+import org.jmock.Expectations;
+import org.jmock.Mockery;
+import org.junit.Test;
+
+import com.google.inject.Guice;
+import com.google.inject.Injector;
+
+public class UnverifiedBatchImplTest extends TestCase {
+
+	private final CryptoComponent crypto;
+	private final byte[] raw, raw1;
+	private final String subject;
+	private final long timestamp;
+
+	public UnverifiedBatchImplTest() {
+		super();
+		Injector i = Guice.createInjector(new CryptoModule());
+		crypto = i.getInstance(CryptoComponent.class);
+		Random r = new Random();
+		raw = new byte[123];
+		r.nextBytes(raw);
+		raw1 = new byte[1234];
+		r.nextBytes(raw1);
+		subject = "Unit tests are exciting";
+		timestamp = System.currentTimeMillis();
+	}
+
+	@Test
+	public void testIds() throws Exception {
+		// Calculate the expected batch and message IDs
+		MessageDigest messageDigest = crypto.getMessageDigest();
+		messageDigest.update(raw);
+		messageDigest.update(raw1);
+		BatchId batchId = new BatchId(messageDigest.digest());
+		messageDigest.update(raw);
+		MessageId messageId = new MessageId(messageDigest.digest());
+		messageDigest.update(raw1);
+		MessageId messageId1 = new MessageId(messageDigest.digest());
+		// Verify the batch
+		Mockery context = new Mockery();
+		final UnverifiedMessage message =
+			context.mock(UnverifiedMessage.class, "message");
+		final UnverifiedMessage message1 =
+			context.mock(UnverifiedMessage.class, "message1");
+		context.checking(new Expectations() {{
+			// First message
+			oneOf(message).getRaw();
+			will(returnValue(raw));
+			oneOf(message).getAuthor();
+			will(returnValue(null));
+			oneOf(message).getGroup();
+			will(returnValue(null));
+			oneOf(message).getParent();
+			will(returnValue(null));
+			oneOf(message).getSubject();
+			will(returnValue(subject));
+			oneOf(message).getTimestamp();
+			will(returnValue(timestamp));
+			oneOf(message).getBodyStart();
+			will(returnValue(10));
+			oneOf(message).getBodyLength();
+			will(returnValue(100));
+			// Second message
+			oneOf(message1).getRaw();
+			will(returnValue(raw1));
+			oneOf(message1).getAuthor();
+			will(returnValue(null));
+			oneOf(message1).getGroup();
+			will(returnValue(null));
+			oneOf(message1).getParent();
+			will(returnValue(null));
+			oneOf(message1).getSubject();
+			will(returnValue(subject));
+			oneOf(message1).getTimestamp();
+			will(returnValue(timestamp));
+			oneOf(message1).getBodyStart();
+			will(returnValue(10));
+			oneOf(message1).getBodyLength();
+			will(returnValue(1000));
+		}});
+		Collection<UnverifiedMessage> messages =
+			Arrays.asList(new UnverifiedMessage[] {message, message1});
+		UnverifiedBatch batch = new UnverifiedBatchImpl(crypto, messages);
+		Batch verifiedBatch = batch.verify();
+		// Check that the batch and message IDs match
+		assertEquals(batchId, verifiedBatch.getId());
+		Collection<Message> verifiedMessages = verifiedBatch.getMessages();
+		assertEquals(2, verifiedMessages.size());
+		Iterator<Message> it = verifiedMessages.iterator();
+		Message verifiedMessage = it.next();
+		assertEquals(messageId, verifiedMessage.getId());
+		Message verifiedMessage1 = it.next();
+		assertEquals(messageId1, verifiedMessage1.getId());
+		context.assertIsSatisfied();
+	}
+
+	@Test
+	public void testSignatures() throws Exception {
+		final int signedByAuthor = 100, signedByGroup = 110;
+		final KeyPair authorKeyPair = crypto.generateKeyPair();
+		final KeyPair groupKeyPair = crypto.generateKeyPair();
+		Signature signature = crypto.getSignature();
+		// Calculate the expected author and group signatures
+		signature.initSign(authorKeyPair.getPrivate());
+		signature.update(raw, 0, signedByAuthor);
+		final byte[] authorSignature = signature.sign();
+		signature.initSign(groupKeyPair.getPrivate());
+		signature.update(raw, 0, signedByGroup);
+		final byte[] groupSignature = signature.sign();
+		// Verify the batch
+		Mockery context = new Mockery();
+		final UnverifiedMessage message =
+			context.mock(UnverifiedMessage.class, "message");
+		final Author author = context.mock(Author.class);
+		final Group group = context.mock(Group.class);
+		final UnverifiedMessage message1 =
+			context.mock(UnverifiedMessage.class, "message1");
+		context.checking(new Expectations() {{
+			// First message
+			oneOf(message).getRaw();
+			will(returnValue(raw));
+			oneOf(message).getAuthor();
+			will(returnValue(author));
+			oneOf(author).getPublicKey();
+			will(returnValue(authorKeyPair.getPublic().getEncoded()));
+			oneOf(message).getLengthSignedByAuthor();
+			will(returnValue(signedByAuthor));
+			oneOf(message).getAuthorSignature();
+			will(returnValue(authorSignature));
+			oneOf(message).getGroup();
+			will(returnValue(group));
+			exactly(2).of(group).getPublicKey();
+			will(returnValue(groupKeyPair.getPublic().getEncoded()));
+			oneOf(message).getLengthSignedByGroup();
+			will(returnValue(signedByGroup));
+			oneOf(message).getGroupSignature();
+			will(returnValue(groupSignature));
+			oneOf(author).getId();
+			will(returnValue(new AuthorId(TestUtils.getRandomId())));
+			oneOf(group).getId();
+			will(returnValue(new GroupId(TestUtils.getRandomId())));
+			oneOf(message).getParent();
+			will(returnValue(null));
+			oneOf(message).getSubject();
+			will(returnValue(subject));
+			oneOf(message).getTimestamp();
+			will(returnValue(timestamp));
+			oneOf(message).getBodyStart();
+			will(returnValue(10));
+			oneOf(message).getBodyLength();
+			will(returnValue(100));
+			// Second message
+			oneOf(message1).getRaw();
+			will(returnValue(raw1));
+			oneOf(message1).getAuthor();
+			will(returnValue(null));
+			oneOf(message1).getGroup();
+			will(returnValue(null));
+			oneOf(message1).getParent();
+			will(returnValue(null));
+			oneOf(message1).getSubject();
+			will(returnValue(subject));
+			oneOf(message1).getTimestamp();
+			will(returnValue(timestamp));
+			oneOf(message1).getBodyStart();
+			will(returnValue(10));
+			oneOf(message1).getBodyLength();
+			will(returnValue(1000));
+		}});
+		Collection<UnverifiedMessage> messages =
+			Arrays.asList(new UnverifiedMessage[] {message, message1});
+		UnverifiedBatch batch = new UnverifiedBatchImpl(crypto, messages);
+		batch.verify();
+		context.assertIsSatisfied();
+	}
+
+	@Test
+	public void testExceptionThrownIfMessageIsModified() throws Exception {
+		final int signedByAuthor = 100;
+		final KeyPair authorKeyPair = crypto.generateKeyPair();
+		Signature signature = crypto.getSignature();
+		// Calculate the expected author signature
+		signature.initSign(authorKeyPair.getPrivate());
+		signature.update(raw, 0, signedByAuthor);
+		final byte[] authorSignature = signature.sign();
+		// Modify the message
+		raw[signedByAuthor / 2] ^= 0xff;
+		// Verify the batch
+		Mockery context = new Mockery();
+		final UnverifiedMessage message =
+			context.mock(UnverifiedMessage.class, "message");
+		final Author author = context.mock(Author.class);
+		final UnverifiedMessage message1 =
+			context.mock(UnverifiedMessage.class, "message1");
+		context.checking(new Expectations() {{
+			// First message - verification will fail at the author's signature
+			oneOf(message).getRaw();
+			will(returnValue(raw));
+			oneOf(message).getAuthor();
+			will(returnValue(author));
+			oneOf(author).getPublicKey();
+			will(returnValue(authorKeyPair.getPublic().getEncoded()));
+			oneOf(message).getLengthSignedByAuthor();
+			will(returnValue(signedByAuthor));
+			oneOf(message).getAuthorSignature();
+			will(returnValue(authorSignature));
+		}});
+		Collection<UnverifiedMessage> messages =
+			Arrays.asList(new UnverifiedMessage[] {message, message1});
+		UnverifiedBatch batch = new UnverifiedBatchImpl(crypto, messages);
+		try {
+			batch.verify();
+			fail();
+		} catch(GeneralSecurityException expected) {}
+		context.assertIsSatisfied();
+	}
+}
-- 
GitLab