From f7274208386941793073b14ecdb24064747e554b Mon Sep 17 00:00:00 2001
From: akwizgran <akwizgran@users.sourceforge.net>
Date: Wed, 20 Jul 2011 18:33:06 +0100
Subject: [PATCH] Removed signatures from headers and bundles, since the
 transport's authentication will make them redundant.

---
 .../net/sf/briar/protocol/BatchReader.java    |  23 +-
 .../sf/briar/protocol/BundleReaderImpl.java   |  16 +-
 .../sf/briar/protocol/BundleWriterImpl.java   |  46 +--
 .../net/sf/briar/protocol/HeaderReader.java   |  21 +-
 test/build.xml                                |   1 +
 .../briar/protocol/BundleReadWriteTest.java   |  85 ++----
 .../briar/protocol/BundleReaderImplTest.java  | 267 +++++++++++++++++-
 7 files changed, 319 insertions(+), 140 deletions(-)

diff --git a/components/net/sf/briar/protocol/BatchReader.java b/components/net/sf/briar/protocol/BatchReader.java
index bfb67d352f..a358c0a16e 100644
--- a/components/net/sf/briar/protocol/BatchReader.java
+++ b/components/net/sf/briar/protocol/BatchReader.java
@@ -3,9 +3,6 @@ package net.sf.briar.protocol;
 import java.io.IOException;
 import java.security.GeneralSecurityException;
 import java.security.MessageDigest;
-import java.security.PublicKey;
-import java.security.Signature;
-import java.security.SignatureException;
 import java.util.List;
 
 import net.sf.briar.api.protocol.Batch;
@@ -17,17 +14,12 @@ import net.sf.briar.api.serial.Reader;
 
 public class BatchReader implements ObjectReader<Batch> {
 
-	private final PublicKey publicKey;
-	private final Signature signature;
 	private final MessageDigest messageDigest;
 	private final ObjectReader<Message> messageReader;
 	private final BatchFactory batchFactory;
 
-	BatchReader(PublicKey publicKey, Signature signature,
-			MessageDigest messageDigest, ObjectReader<Message> messageReader,
-			BatchFactory batchFactory) {
-		this.publicKey = publicKey;
-		this.signature = signature;
+	BatchReader(MessageDigest messageDigest,
+			ObjectReader<Message> messageReader, BatchFactory batchFactory) {
 		this.messageDigest = messageDigest;
 		this.messageReader = messageReader;
 		this.batchFactory = batchFactory;
@@ -35,25 +27,18 @@ public class BatchReader implements ObjectReader<Batch> {
 
 	public Batch readObject(Reader reader) throws IOException,
 	GeneralSecurityException {
-		// Initialise the input stream
+		// Initialise the consumers
 		CountingConsumer counting = new CountingConsumer(Batch.MAX_SIZE);
 		DigestingConsumer digesting = new DigestingConsumer(messageDigest);
 		messageDigest.reset();
-		SigningConsumer signing = new SigningConsumer(signature);
-		signature.initVerify(publicKey);
-		// Read the signed data
+		// Read and digest the data
 		reader.addConsumer(counting);
 		reader.addConsumer(digesting);
-		reader.addConsumer(signing);
 		reader.addObjectReader(Tags.MESSAGE, messageReader);
 		List<Message> messages = reader.readList(Message.class);
 		reader.removeObjectReader(Tags.MESSAGE);
-		reader.removeConsumer(signing);
-		// Read and verify the signature
-		byte[] sig = reader.readRaw();
 		reader.removeConsumer(digesting);
 		reader.removeConsumer(counting);
-		if(!signature.verify(sig)) throw new SignatureException();
 		// Build and return the batch
 		BatchId id = new BatchId(messageDigest.digest());
 		return batchFactory.createBatch(id, messages);
diff --git a/components/net/sf/briar/protocol/BundleReaderImpl.java b/components/net/sf/briar/protocol/BundleReaderImpl.java
index db12a3cd9b..e82bdb6282 100644
--- a/components/net/sf/briar/protocol/BundleReaderImpl.java
+++ b/components/net/sf/briar/protocol/BundleReaderImpl.java
@@ -13,7 +13,7 @@ import net.sf.briar.api.serial.Reader;
 
 class BundleReaderImpl implements BundleReader {
 
-	private static enum State { START, FIRST_BATCH, MORE_BATCHES, END };
+	private static enum State { START, BATCHES, END };
 
 	private final Reader reader;
 	private final ObjectReader<Header> headerReader;
@@ -29,21 +29,19 @@ class BundleReaderImpl implements BundleReader {
 
 	public Header getHeader() throws IOException, GeneralSecurityException {
 		if(state != State.START) throw new IllegalStateException();
-		reader.addObjectReader(Tags.HEADER, headerReader);
 		reader.readUserDefinedTag(Tags.HEADER);
+		reader.addObjectReader(Tags.HEADER, headerReader);
 		Header h = reader.readUserDefinedObject(Tags.HEADER, Header.class);
 		reader.removeObjectReader(Tags.HEADER);
-		state = State.FIRST_BATCH;
+		// Expect a list of batches
+		reader.readListStart();
+		reader.addObjectReader(Tags.BATCH, batchReader);
+		state = State.BATCHES;
 		return h;
 	}
 
 	public Batch getNextBatch() throws IOException, GeneralSecurityException {
-		if(state == State.FIRST_BATCH) {
-			reader.readListStart();
-			reader.addObjectReader(Tags.BATCH, batchReader);
-			state = State.MORE_BATCHES;
-		}
-		if(state != State.MORE_BATCHES) throw new IllegalStateException();
+		if(state != State.BATCHES) throw new IllegalStateException();
 		if(reader.hasListEnd()) {
 			reader.removeObjectReader(Tags.BATCH);
 			reader.readListEnd();
diff --git a/components/net/sf/briar/protocol/BundleWriterImpl.java b/components/net/sf/briar/protocol/BundleWriterImpl.java
index e6ff92b7fe..ca4096d16e 100644
--- a/components/net/sf/briar/protocol/BundleWriterImpl.java
+++ b/components/net/sf/briar/protocol/BundleWriterImpl.java
@@ -2,10 +2,9 @@ package net.sf.briar.protocol;
 
 import java.io.IOException;
 import java.io.OutputStream;
+import java.security.DigestOutputStream;
 import java.security.GeneralSecurityException;
 import java.security.MessageDigest;
-import java.security.PrivateKey;
-import java.security.Signature;
 import java.util.Collection;
 import java.util.Map;
 
@@ -21,22 +20,17 @@ class BundleWriterImpl implements BundleWriter {
 
 	private static enum State { START, FIRST_BATCH, MORE_BATCHES, END };
 
-	private final SigningDigestingOutputStream out;
+	private final DigestOutputStream out;
 	private final Writer writer;
-	private final PrivateKey privateKey;
-	private final Signature signature;
 	private final MessageDigest messageDigest;
 	private final long capacity;
 	private State state = State.START;
 
 	BundleWriterImpl(OutputStream out, WriterFactory writerFactory,
-			PrivateKey privateKey, Signature signature,
 			MessageDigest messageDigest, long capacity) {
-		this.out =
-			new SigningDigestingOutputStream(out, signature, messageDigest);
+		this.out = new DigestOutputStream(out, messageDigest);
+		this.out.on(false); // Turn off the digest until we need it
 		writer = writerFactory.createWriter(this.out);
-		this.privateKey = privateKey;
-		this.signature = signature;
 		this.messageDigest = messageDigest;
 		this.capacity = capacity;
 	}
@@ -49,24 +43,13 @@ class BundleWriterImpl implements BundleWriter {
 			Map<String, String> transports) throws IOException,
 			GeneralSecurityException {
 		if(state != State.START) throw new IllegalStateException();
-		// Initialise the output stream
-		signature.initSign(privateKey);
 		// Write the initial tag
 		writer.writeUserDefinedTag(Tags.HEADER);
-		// Write the data to be signed
-		out.setSigning(true);
-		// Acks
+		// Write the data
 		writer.writeList(acks);
-		// Subs
 		writer.writeList(subs);
-		// Transports
 		writer.writeMap(transports);
-		// Timestamp
 		writer.writeInt64(System.currentTimeMillis());
-		out.setSigning(false);
-		// Create and write the signature
-		byte[] sig = signature.sign();
-		writer.writeRaw(sig);
 		// Expect a (possibly empty) list of batches
 		state = State.FIRST_BATCH;
 	}
@@ -78,26 +61,21 @@ class BundleWriterImpl implements BundleWriter {
 			state = State.MORE_BATCHES;
 		}
 		if(state != State.MORE_BATCHES) throw new IllegalStateException();
-		// Initialise the output stream
-		signature.initSign(privateKey);
-		messageDigest.reset();
 		// Write the initial tag
 		writer.writeUserDefinedTag(Tags.BATCH);
-		// Write the data to be signed
-		out.setDigesting(true);
-		out.setSigning(true);
+		// Start digesting
+		messageDigest.reset();
+		out.on(true);
+		// Write the data
 		writer.writeListStart();
-		// Bypass the writer and write the raw messages directly
+		// Bypass the writer and write each raw message directly
 		for(Raw message : messages) {
 			writer.writeUserDefinedTag(Tags.MESSAGE);
 			out.write(message.getBytes());
 		}
 		writer.writeListEnd();
-		out.setSigning(false);
-		// Create and write the signature
-		byte[] sig = signature.sign();
-		writer.writeRaw(sig);
-		out.setDigesting(false);
+		// Stop digesting
+		out.on(false);
 		// Calculate and return the ID
 		return new BatchId(messageDigest.digest());
 	}
diff --git a/components/net/sf/briar/protocol/HeaderReader.java b/components/net/sf/briar/protocol/HeaderReader.java
index aeaf93e2b9..7db1158300 100644
--- a/components/net/sf/briar/protocol/HeaderReader.java
+++ b/components/net/sf/briar/protocol/HeaderReader.java
@@ -2,9 +2,6 @@ package net.sf.briar.protocol;
 
 import java.io.IOException;
 import java.security.GeneralSecurityException;
-import java.security.PublicKey;
-import java.security.Signature;
-import java.security.SignatureException;
 import java.util.Collection;
 import java.util.Map;
 
@@ -18,26 +15,17 @@ import net.sf.briar.api.serial.Reader;
 
 class HeaderReader implements ObjectReader<Header> {
 
-	private final PublicKey publicKey;
-	private final Signature signature;
 	private final HeaderFactory headerFactory;
 
-	HeaderReader(PublicKey publicKey, Signature signature,
-			HeaderFactory headerFactory) {
-		this.publicKey = publicKey;
-		this.signature = signature;
+	HeaderReader(HeaderFactory headerFactory) {
 		this.headerFactory = headerFactory;
 	}
 
 	public Header readObject(Reader reader) throws IOException,
 	GeneralSecurityException {
-		// Initialise the input stream
+		// Initialise and add the consumer
 		CountingConsumer counting = new CountingConsumer(Header.MAX_SIZE);
-		SigningConsumer signing = new SigningConsumer(signature);
-		signature.initVerify(publicKey);
-		// Read the signed data
 		reader.addConsumer(counting);
-		reader.addConsumer(signing);
 		// Acks
 		reader.addObjectReader(Tags.BATCH_ID, new BatchIdReader());
 		Collection<BatchId> acks = reader.readList(BatchId.class);
@@ -52,11 +40,8 @@ class HeaderReader implements ObjectReader<Header> {
 		// Timestamp
 		long timestamp = reader.readInt64();
 		if(timestamp < 0L) throw new FormatException();
-		reader.removeConsumer(signing);
-		// Read and verify the signature
-		byte[] sig = reader.readRaw();
+		// Remove the consumer
 		reader.removeConsumer(counting);
-		if(!signature.verify(sig)) throw new SignatureException();
 		// Build and return the header
 		return headerFactory.createHeader(acks, subs, transports, timestamp);
 	}
diff --git a/test/build.xml b/test/build.xml
index f51ac300af..5d0d255b7b 100644
--- a/test/build.xml
+++ b/test/build.xml
@@ -20,6 +20,7 @@
 			<test name='net.sf.briar.i18n.FontManagerTest'/>
 			<test name='net.sf.briar.i18n.I18nTest'/>
 			<test name='net.sf.briar.invitation.InvitationWorkerTest'/>
+			<test name='net.sf.briar.protocol.BundleReaderImplTest'/>
 			<test name='net.sf.briar.protocol.BundleReadWriteTest'/>
 			<test name='net.sf.briar.protocol.ConsumersTest'/>
 			<test name='net.sf.briar.protocol.SigningDigestingOutputStreamTest'/>
diff --git a/test/net/sf/briar/protocol/BundleReadWriteTest.java b/test/net/sf/briar/protocol/BundleReadWriteTest.java
index c20e261abd..1a32d6e73a 100644
--- a/test/net/sf/briar/protocol/BundleReadWriteTest.java
+++ b/test/net/sf/briar/protocol/BundleReadWriteTest.java
@@ -3,8 +3,6 @@ package net.sf.briar.protocol;
 import java.io.File;
 import java.io.FileInputStream;
 import java.io.FileOutputStream;
-import java.io.RandomAccessFile;
-import java.security.GeneralSecurityException;
 import java.security.KeyFactory;
 import java.security.KeyPair;
 import java.security.KeyPairGenerator;
@@ -66,25 +64,23 @@ public class BundleReadWriteTest extends TestCase {
 	private final String nick = "Foo Bar";
 	private final String messageBody = "This is the message body! Wooooooo!";
 
-	private final ReaderFactory rf;
-	private final WriterFactory wf;
-
-	private final KeyPair keyPair;
-	private final Signature sig, sig1;
-	private final MessageDigest dig, dig1;
+	private final ReaderFactory readerFactory;
+	private final WriterFactory writerFactory;
+	private final Signature signature;
+	private final MessageDigest messageDigest, batchDigest;
 	private final KeyParser keyParser;
 	private final Message message;
 
 	public BundleReadWriteTest() throws Exception {
 		super();
+		// Inject the reader and writer factories, since they belong to
+		// a different component
 		Injector i = Guice.createInjector(new SerialModule());
-		rf = i.getInstance(ReaderFactory.class);
-		wf = i.getInstance(WriterFactory.class);
-		keyPair = KeyPairGenerator.getInstance(KEY_PAIR_ALGO).generateKeyPair();
-		sig = Signature.getInstance(SIGNATURE_ALGO);
-		sig1 = Signature.getInstance(SIGNATURE_ALGO);
-		dig = MessageDigest.getInstance(DIGEST_ALGO);
-		dig1 = MessageDigest.getInstance(DIGEST_ALGO);
+		readerFactory = i.getInstance(ReaderFactory.class);
+		writerFactory = i.getInstance(WriterFactory.class);
+		signature = Signature.getInstance(SIGNATURE_ALGO);
+		messageDigest = MessageDigest.getInstance(DIGEST_ALGO);
+		batchDigest = MessageDigest.getInstance(DIGEST_ALGO);
 		final KeyFactory keyFactory = KeyFactory.getInstance(KEY_PAIR_ALGO);
 		keyParser = new KeyParser() {
 			public PublicKey parsePublicKey(byte[] encodedKey)
@@ -93,8 +89,13 @@ public class BundleReadWriteTest extends TestCase {
 				return keyFactory.generatePublic(e);
 			}
 		};
-		assertEquals(dig.getDigestLength(), UniqueId.LENGTH);
-		MessageEncoder messageEncoder = new MessageEncoderImpl(sig, dig, wf);
+		assertEquals(messageDigest.getDigestLength(), UniqueId.LENGTH);
+		assertEquals(batchDigest.getDigestLength(), UniqueId.LENGTH);
+		// Create and encode a test message
+		MessageEncoder messageEncoder = new MessageEncoderImpl(signature,
+				messageDigest, writerFactory);
+		KeyPair keyPair =
+			KeyPairGenerator.getInstance(KEY_PAIR_ALGO).generateKeyPair();
 		message = messageEncoder.encodeMessage(MessageId.NONE, sub, nick,
 				keyPair, messageBody.getBytes("UTF-8"));
 	}
@@ -107,8 +108,8 @@ public class BundleReadWriteTest extends TestCase {
 	@Test
 	public void testWriteBundle() throws Exception {
 		FileOutputStream out = new FileOutputStream(bundle);
-		BundleWriter w = new BundleWriterImpl(out, wf, keyPair.getPrivate(),
-				sig, dig, capacity);
+		BundleWriter w = new BundleWriterImpl(out, writerFactory, batchDigest,
+				capacity);
 		Raw messageRaw = new RawByteArray(message.getBytes());
 
 		w.addHeader(acks, subs, transports);
@@ -125,12 +126,12 @@ public class BundleReadWriteTest extends TestCase {
 		testWriteBundle();
 
 		FileInputStream in = new FileInputStream(bundle);
-		Reader reader = rf.createReader(in);
-		MessageReader messageReader = new MessageReader(keyParser, sig1, dig1);
-		HeaderReader headerReader = new HeaderReader(keyPair.getPublic(), sig,
-				new HeaderFactoryImpl());
-		BatchReader batchReader = new BatchReader(keyPair.getPublic(), sig, dig,
-				messageReader, new BatchFactoryImpl());
+		Reader reader = readerFactory.createReader(in);
+		MessageReader messageReader =
+			new MessageReader(keyParser, signature, messageDigest);
+		HeaderReader headerReader = new HeaderReader(new HeaderFactoryImpl());
+		BatchReader batchReader = new BatchReader(batchDigest, messageReader,
+				new BatchFactoryImpl());
 		BundleReader r = new BundleReaderImpl(reader, headerReader,
 				batchReader);
 
@@ -153,40 +154,6 @@ public class BundleReadWriteTest extends TestCase {
 		r.finish();
 	}
 
-
-	@Test
-	public void testModifyingBundleBreaksSignature() throws Exception {
-
-		testWriteBundle();
-
-		RandomAccessFile f = new RandomAccessFile(bundle, "rw");
-		f.seek(bundle.length() - 100);
-		byte b = f.readByte();
-		f.seek(bundle.length() - 100);
-		f.writeByte(b + 1);
-		f.close();
-
-		FileInputStream in = new FileInputStream(bundle);
-		Reader reader = rf.createReader(in);
-		MessageReader messageReader = new MessageReader(keyParser, sig1, dig1);
-		HeaderReader headerReader = new HeaderReader(keyPair.getPublic(), sig,
-				new HeaderFactoryImpl());
-		BatchReader batchReader = new BatchReader(keyPair.getPublic(), sig, dig,
-				messageReader, new BatchFactoryImpl());
-		BundleReader r = new BundleReaderImpl(reader, headerReader,
-				batchReader);
-
-		Header h = r.getHeader();
-		assertEquals(acks, h.getAcks());
-		assertEquals(subs, h.getSubscriptions());
-		assertEquals(transports, h.getTransports());
-		try {
-			r.getNextBatch();
-			assertTrue(false);
-		} catch(GeneralSecurityException expected) {}
-		r.finish();
-	}
-
 	@After
 	public void tearDown() {
 		TestUtils.deleteTestDirectory(testDir);
diff --git a/test/net/sf/briar/protocol/BundleReaderImplTest.java b/test/net/sf/briar/protocol/BundleReaderImplTest.java
index 228a1a5ec0..8e5947777c 100644
--- a/test/net/sf/briar/protocol/BundleReaderImplTest.java
+++ b/test/net/sf/briar/protocol/BundleReaderImplTest.java
@@ -1,5 +1,270 @@
 package net.sf.briar.protocol;
 
-public class BundleReaderImplTest {
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.security.GeneralSecurityException;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Set;
 
+import junit.framework.TestCase;
+import net.sf.briar.api.protocol.Batch;
+import net.sf.briar.api.protocol.BatchId;
+import net.sf.briar.api.protocol.GroupId;
+import net.sf.briar.api.protocol.Header;
+import net.sf.briar.api.protocol.Message;
+import net.sf.briar.api.protocol.Tags;
+import net.sf.briar.api.serial.FormatException;
+import net.sf.briar.api.serial.ObjectReader;
+import net.sf.briar.api.serial.Reader;
+import net.sf.briar.api.serial.ReaderFactory;
+import net.sf.briar.api.serial.Writer;
+import net.sf.briar.api.serial.WriterFactory;
+import net.sf.briar.serial.SerialModule;
+
+import org.junit.Test;
+
+import com.google.inject.Guice;
+import com.google.inject.Injector;
+
+public class BundleReaderImplTest extends TestCase {
+
+	private final ReaderFactory readerFactory;
+	private final WriterFactory writerFactory;
+
+	public BundleReaderImplTest() {
+		Injector i = Guice.createInjector(new SerialModule());
+		readerFactory = i.getInstance(ReaderFactory.class);
+		writerFactory = i.getInstance(WriterFactory.class);
+	}
+
+	@Test
+	public void testEmptyBundleThrowsFormatException() throws Exception {
+		ByteArrayInputStream in = new ByteArrayInputStream(new byte[] {});
+		Reader r = readerFactory.createReader(in);
+		BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
+				new TestBatchReader());
+
+		try {
+			b.getHeader();
+			assertTrue(false);
+		} catch(FormatException expected) {}
+	}
+
+	@Test
+	public void testReadingBatchBeforeHeaderThrowsIllegalStateException()
+	throws Exception {
+		ByteArrayInputStream in = new ByteArrayInputStream(createValidBundle());
+		Reader r = readerFactory.createReader(in);
+		BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
+				new TestBatchReader());
+
+		try {
+			b.getNextBatch();
+			assertTrue(false);
+		} catch(IllegalStateException expected) {}
+	}
+
+	@Test
+	public void testMissingHeaderThrowsFormatException() throws Exception {
+		// Create a headless bundle
+		ByteArrayOutputStream out = new ByteArrayOutputStream();
+		Writer w = writerFactory.createWriter(out);
+		w.writeListStart();
+		w.writeUserDefinedTag(Tags.BATCH);
+		w.writeList(Collections.emptyList());
+		w.writeListEnd();
+		w.close();
+		byte[] headless = out.toByteArray();
+		// Try to read a header from the headless bundle
+		ByteArrayInputStream in = new ByteArrayInputStream(headless);
+		Reader r = readerFactory.createReader(in);
+		BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
+				new TestBatchReader());
+
+		try {
+			b.getHeader();
+			assertTrue(false);
+		} catch(FormatException expected) {}
+	}
+
+	@Test
+	public void testMissingBatchListThrowsFormatException() throws Exception {
+		// Create a header-only bundle
+		ByteArrayOutputStream out = new ByteArrayOutputStream();
+		Writer w = writerFactory.createWriter(out);
+		w.writeUserDefinedTag(Tags.HEADER);
+		w.writeList(Collections.emptyList()); // Acks
+		w.writeList(Collections.emptyList()); // Subs
+		w.writeMap(Collections.emptyMap()); // Transports
+		w.writeInt64(System.currentTimeMillis()); // Timestamp
+		w.close();
+		byte[] headerOnly = out.toByteArray();
+		// Try to read a header from the header-only bundle
+		ByteArrayInputStream in = new ByteArrayInputStream(headerOnly);
+		final Reader r = readerFactory.createReader(in);
+		BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
+				new TestBatchReader());
+
+		try {
+			b.getHeader();
+			assertTrue(false);
+		} catch(FormatException expected) {}
+	}
+
+	@Test
+	public void testEmptyBatchListIsAcceptable() throws Exception {
+		// Create a bundle with no batches
+		ByteArrayOutputStream out = new ByteArrayOutputStream();
+		Writer w = writerFactory.createWriter(out);
+		w.writeUserDefinedTag(Tags.HEADER);
+		w.writeList(Collections.emptyList()); // Acks
+		w.writeList(Collections.emptyList()); // Subs
+		w.writeMap(Collections.emptyMap()); // Transports
+		w.writeInt64(System.currentTimeMillis()); // Timestamp
+		w.writeListStart();
+		w.writeListEnd();
+		w.close();
+		byte[] batchless = out.toByteArray();
+		// It should be possible to read the header and null
+		ByteArrayInputStream in = new ByteArrayInputStream(batchless);
+		final Reader r = readerFactory.createReader(in);
+		BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
+				new TestBatchReader());
+
+		assertNotNull(b.getHeader());
+		assertNull(b.getNextBatch());
+	}
+
+	@Test
+	public void testValidBundle() throws Exception {
+		// It should be possible to read the header, a batch, and null
+		ByteArrayInputStream in = new ByteArrayInputStream(createValidBundle());
+		final Reader r = readerFactory.createReader(in);
+		BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
+				new TestBatchReader());
+
+		assertNotNull(b.getHeader());
+		assertNotNull(b.getNextBatch());
+		assertNull(b.getNextBatch());
+	}
+
+	@Test
+	public void testReadingBatchAfterNullThrowsIllegalStateException()
+	throws Exception {
+		// Trying to read another batch after null should not succeed
+		ByteArrayInputStream in = new ByteArrayInputStream(createValidBundle());
+		final Reader r = readerFactory.createReader(in);
+		BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
+				new TestBatchReader());
+
+		assertNotNull(b.getHeader());
+		assertNotNull(b.getNextBatch());
+		assertNull(b.getNextBatch());
+		try {
+			b.getNextBatch();
+			assertTrue(false);
+		} catch(IllegalStateException expected) {}
+	}
+
+	@Test
+	public void testReadingHeaderTwiceThrowsIllegalStateException()
+	throws Exception {
+		// Trying to read the header twice should not succeed
+		ByteArrayInputStream in = new ByteArrayInputStream(createValidBundle());
+		final Reader r = readerFactory.createReader(in);
+		BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
+				new TestBatchReader());
+
+		assertNotNull(b.getHeader());
+		try {
+			b.getHeader();
+			assertTrue(false);
+		} catch(IllegalStateException expected) {}
+	}
+
+	@Test
+	public void testReadingHeaderAfterBatchThrowsIllegalStateException()
+	throws Exception {
+		// Trying to read the header after a batch should not succeed
+		ByteArrayInputStream in = new ByteArrayInputStream(createValidBundle());
+		final Reader r = readerFactory.createReader(in);
+		BundleReaderImpl b = new BundleReaderImpl(r, new TestHeaderReader(),
+				new TestBatchReader());
+
+		assertNotNull(b.getHeader());
+		assertNotNull(b.getNextBatch());
+		try {
+			b.getHeader();
+			assertTrue(false);
+		} catch(IllegalStateException expected) {}
+	}
+
+	private byte[] createValidBundle() throws IOException {
+		ByteArrayOutputStream out = new ByteArrayOutputStream();
+		Writer w = writerFactory.createWriter(out);
+		w.writeUserDefinedTag(Tags.HEADER);
+		w.writeList(Collections.emptyList()); // Acks
+		w.writeList(Collections.emptyList()); // Subs
+		w.writeMap(Collections.emptyMap()); // Transports
+		w.writeInt64(System.currentTimeMillis()); // Timestamp
+		w.writeListStart();
+		w.writeUserDefinedTag(Tags.BATCH);
+		w.writeList(Collections.emptyList()); // Messages
+		w.writeListEnd();
+		w.close();
+		return out.toByteArray();
+	}
+
+	private static class TestHeaderReader implements ObjectReader<Header> {
+
+		public Header readObject(Reader r) throws IOException,
+		GeneralSecurityException {
+			r.readList();
+			r.readList();
+			r.readMap();
+			r.readInt64();
+			return new TestHeader();
+		}
+	}
+
+	private static class TestHeader implements Header {
+
+		public Set<BatchId> getAcks() {
+			return null;
+		}
+
+		public Set<GroupId> getSubscriptions() {
+			return null;
+		}
+
+		public Map<String, String> getTransports() {
+			return null;
+		}
+
+		public long getTimestamp() {
+			return 0;
+		}
+	}
+
+	private static class TestBatchReader implements ObjectReader<Batch> {
+
+		public Batch readObject(Reader r) throws IOException,
+		GeneralSecurityException {
+			r.readList();
+			return new TestBatch();
+		}
+	}
+
+	private static class TestBatch implements Batch {
+
+		public BatchId getId() {
+			return null;
+		}
+
+		public Iterable<Message> getMessages() {
+			return null;
+		}
+	}
 }
-- 
GitLab