From 1c41ffa7afd559364f5c96806c5c5da32c7da018 Mon Sep 17 00:00:00 2001
From: akwizgran <akwizgran@users.sourceforge.net>
Date: Wed, 7 Dec 2011 13:32:17 +0000
Subject: [PATCH] Don't accept empty acks, batches or offers.

---
 .../net/sf/briar/protocol/AckReader.java      |   4 +-
 .../net/sf/briar/protocol/BatchReader.java    |   2 +
 .../net/sf/briar/protocol/OfferReader.java    |   4 +-
 test/build.xml                                |   1 +
 test/net/sf/briar/protocol/AckReaderTest.java |  31 ++---
 .../sf/briar/protocol/BatchReaderTest.java    |  12 +-
 .../sf/briar/protocol/OfferReaderTest.java    | 121 ++++++++++++++++++
 7 files changed, 143 insertions(+), 32 deletions(-)
 create mode 100644 test/net/sf/briar/protocol/OfferReaderTest.java

diff --git a/components/net/sf/briar/protocol/AckReader.java b/components/net/sf/briar/protocol/AckReader.java
index f2c4a899f4..977eaa8dab 100644
--- a/components/net/sf/briar/protocol/AckReader.java
+++ b/components/net/sf/briar/protocol/AckReader.java
@@ -2,7 +2,6 @@ package net.sf.briar.protocol;
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 
@@ -35,9 +34,10 @@ class AckReader implements ObjectReader<Ack> {
 		r.addConsumer(counting);
 		r.readStructId(Types.ACK);
 		r.setMaxBytesLength(UniqueId.LENGTH);
-		Collection<Bytes> raw = r.readList(Bytes.class);
+		List<Bytes> raw = r.readList(Bytes.class);
 		r.resetMaxBytesLength();
 		r.removeConsumer(counting);
+		if(raw.isEmpty()) throw new FormatException();
 		// Convert the byte arrays to batch IDs
 		List<BatchId> batches = new ArrayList<BatchId>();
 		for(Bytes b : raw) {
diff --git a/components/net/sf/briar/protocol/BatchReader.java b/components/net/sf/briar/protocol/BatchReader.java
index 1f67b3d1b0..de7f5e31e4 100644
--- a/components/net/sf/briar/protocol/BatchReader.java
+++ b/components/net/sf/briar/protocol/BatchReader.java
@@ -3,6 +3,7 @@ package net.sf.briar.protocol;
 import java.io.IOException;
 import java.util.List;
 
+import net.sf.briar.api.FormatException;
 import net.sf.briar.api.crypto.CryptoComponent;
 import net.sf.briar.api.crypto.MessageDigest;
 import net.sf.briar.api.protocol.BatchId;
@@ -44,6 +45,7 @@ class BatchReader implements ObjectReader<UnverifiedBatch> {
 		r.removeObjectReader(Types.MESSAGE);
 		r.removeConsumer(digesting);
 		r.removeConsumer(counting);
+		if(messages.isEmpty()) throw new FormatException();
 		// Build and return the batch
 		BatchId id = new BatchId(messageDigest.digest());
 		return batchFactory.createUnverifiedBatch(id, messages);
diff --git a/components/net/sf/briar/protocol/OfferReader.java b/components/net/sf/briar/protocol/OfferReader.java
index f87247dda4..d7e7dcd878 100644
--- a/components/net/sf/briar/protocol/OfferReader.java
+++ b/components/net/sf/briar/protocol/OfferReader.java
@@ -2,7 +2,6 @@ package net.sf.briar.protocol;
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 
@@ -35,9 +34,10 @@ class OfferReader implements ObjectReader<Offer> {
 		r.addConsumer(counting);
 		r.readStructId(Types.OFFER);
 		r.setMaxBytesLength(UniqueId.LENGTH);
-		Collection<Bytes> raw = r.readList(Bytes.class);
+		List<Bytes> raw = r.readList(Bytes.class);
 		r.resetMaxBytesLength();
 		r.removeConsumer(counting);
+		if(raw.isEmpty()) throw new FormatException();
 		// Convert the byte arrays to message IDs
 		List<MessageId> messages = new ArrayList<MessageId>();
 		for(Bytes b : raw) {
diff --git a/test/build.xml b/test/build.xml
index f40c9ce045..d865f5bd5a 100644
--- a/test/build.xml
+++ b/test/build.xml
@@ -39,6 +39,7 @@
 			<test name='net.sf.briar.protocol.BatchReaderTest'/>
 			<test name='net.sf.briar.protocol.ConstantsTest'/>
 			<test name='net.sf.briar.protocol.ConsumersTest'/>
+			<test name='net.sf.briar.protocol.OfferReaderTest'/>
 			<test name='net.sf.briar.protocol.ProtocolReadWriteTest'/>
 			<test name='net.sf.briar.protocol.ProtocolWriterImplTest'/>
 			<test name='net.sf.briar.protocol.RequestReaderTest'/>
diff --git a/test/net/sf/briar/protocol/AckReaderTest.java b/test/net/sf/briar/protocol/AckReaderTest.java
index 3ff8ab8e0d..878718082f 100644
--- a/test/net/sf/briar/protocol/AckReaderTest.java
+++ b/test/net/sf/briar/protocol/AckReaderTest.java
@@ -3,19 +3,17 @@ package net.sf.briar.protocol;
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.util.Collection;
-import java.util.Collections;
-import java.util.Random;
 
 import junit.framework.TestCase;
+import net.sf.briar.TestUtils;
 import net.sf.briar.api.FormatException;
 import net.sf.briar.api.protocol.Ack;
-import net.sf.briar.api.protocol.BatchId;
 import net.sf.briar.api.protocol.PacketFactory;
 import net.sf.briar.api.protocol.ProtocolConstants;
 import net.sf.briar.api.protocol.Types;
-import net.sf.briar.api.protocol.UniqueId;
 import net.sf.briar.api.serial.Reader;
 import net.sf.briar.api.serial.ReaderFactory;
+import net.sf.briar.api.serial.SerialComponent;
 import net.sf.briar.api.serial.Writer;
 import net.sf.briar.api.serial.WriterFactory;
 import net.sf.briar.serial.SerialModule;
@@ -29,6 +27,7 @@ import com.google.inject.Injector;
 
 public class AckReaderTest extends TestCase {
 
+	private final SerialComponent serial;
 	private final ReaderFactory readerFactory;
 	private final WriterFactory writerFactory;
 	private final Mockery context;
@@ -36,6 +35,7 @@ public class AckReaderTest extends TestCase {
 	public AckReaderTest() throws Exception {
 		super();
 		Injector i = Guice.createInjector(new SerialModule());
+		serial = i.getInstance(SerialComponent.class);
 		readerFactory = i.getInstance(ReaderFactory.class);
 		writerFactory = i.getInstance(WriterFactory.class);
 		context = new Mockery();
@@ -82,19 +82,16 @@ public class AckReaderTest extends TestCase {
 	public void testEmptyAck() throws Exception {
 		final PacketFactory packetFactory = context.mock(PacketFactory.class);
 		AckReader ackReader = new AckReader(packetFactory);
-		final Ack ack = context.mock(Ack.class);
-		context.checking(new Expectations() {{
-			oneOf(packetFactory).createAck(
-					with(Collections.<BatchId>emptyList()));
-			will(returnValue(ack));
-		}});
 
 		byte[] b = createEmptyAck();
 		ByteArrayInputStream in = new ByteArrayInputStream(b);
 		Reader reader = readerFactory.createReader(in);
 		reader.addObjectReader(Types.ACK, ackReader);
 
-		assertEquals(ack, reader.readStruct(Types.ACK, Ack.class));
+		try {
+			reader.readStruct(Types.ACK, Ack.class);
+			fail();
+		} catch(FormatException expected) {}
 		context.assertIsSatisfied();
 	}
 
@@ -103,17 +100,11 @@ public class AckReaderTest extends TestCase {
 		Writer w = writerFactory.createWriter(out);
 		w.writeStructId(Types.ACK);
 		w.writeListStart();
-		byte[] b = new byte[UniqueId.LENGTH];
-		Random random = new Random();
-		while(out.size() + BatchId.LENGTH + 3
+		while(out.size() + serial.getSerialisedUniqueIdLength()
 				< ProtocolConstants.MAX_PACKET_LENGTH) {
-			random.nextBytes(b);
-			w.writeBytes(b);
-		}
-		if(tooBig) {
-			random.nextBytes(b);
-			w.writeBytes(b);
+			w.writeBytes(TestUtils.getRandomId());
 		}
+		if(tooBig) w.writeBytes(TestUtils.getRandomId());
 		w.writeListEnd();
 		assertEquals(tooBig, out.size() > ProtocolConstants.MAX_PACKET_LENGTH);
 		return out.toByteArray();
diff --git a/test/net/sf/briar/protocol/BatchReaderTest.java b/test/net/sf/briar/protocol/BatchReaderTest.java
index 8f8898fd5d..6ff7bfecee 100644
--- a/test/net/sf/briar/protocol/BatchReaderTest.java
+++ b/test/net/sf/briar/protocol/BatchReaderTest.java
@@ -127,20 +127,16 @@ public class BatchReaderTest extends TestCase {
 			context.mock(UnverifiedBatchFactory.class);
 		BatchReader batchReader = new BatchReader(crypto, messageReader,
 				batchFactory);
-		final UnverifiedBatch batch = context.mock(UnverifiedBatch.class);
-		context.checking(new Expectations() {{
-			oneOf(batchFactory).createUnverifiedBatch(with(any(BatchId.class)),
-					with(Collections.<UnverifiedMessage>emptyList()));
-			will(returnValue(batch));
-		}});
 
 		byte[] b = createEmptyBatch();
 		ByteArrayInputStream in = new ByteArrayInputStream(b);
 		Reader reader = readerFactory.createReader(in);
 		reader.addObjectReader(Types.BATCH, batchReader);
 
-		assertEquals(batch, reader.readStruct(Types.BATCH,
-				UnverifiedBatch.class));
+		try {
+			reader.readStruct(Types.BATCH, UnverifiedBatch.class);
+			fail();
+		} catch(FormatException expected) {}
 		context.assertIsSatisfied();
 	}
 
diff --git a/test/net/sf/briar/protocol/OfferReaderTest.java b/test/net/sf/briar/protocol/OfferReaderTest.java
new file mode 100644
index 0000000000..6c25e129d3
--- /dev/null
+++ b/test/net/sf/briar/protocol/OfferReaderTest.java
@@ -0,0 +1,121 @@
+package net.sf.briar.protocol;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.util.Collection;
+
+import junit.framework.TestCase;
+import net.sf.briar.TestUtils;
+import net.sf.briar.api.FormatException;
+import net.sf.briar.api.protocol.Offer;
+import net.sf.briar.api.protocol.PacketFactory;
+import net.sf.briar.api.protocol.ProtocolConstants;
+import net.sf.briar.api.protocol.Types;
+import net.sf.briar.api.serial.Reader;
+import net.sf.briar.api.serial.ReaderFactory;
+import net.sf.briar.api.serial.SerialComponent;
+import net.sf.briar.api.serial.Writer;
+import net.sf.briar.api.serial.WriterFactory;
+import net.sf.briar.serial.SerialModule;
+
+import org.jmock.Expectations;
+import org.jmock.Mockery;
+import org.junit.Test;
+
+import com.google.inject.Guice;
+import com.google.inject.Injector;
+
+public class OfferReaderTest extends TestCase {
+
+	private final SerialComponent serial;
+	private final ReaderFactory readerFactory;
+	private final WriterFactory writerFactory;
+	private final Mockery context;
+
+	public OfferReaderTest() throws Exception {
+		super();
+		Injector i = Guice.createInjector(new SerialModule());
+		serial = i.getInstance(SerialComponent.class);
+		readerFactory = i.getInstance(ReaderFactory.class);
+		writerFactory = i.getInstance(WriterFactory.class);
+		context = new Mockery();
+	}
+
+	@Test
+	public void testFormatExceptionIfOfferIsTooLarge() throws Exception {
+		PacketFactory packetFactory = context.mock(PacketFactory.class);
+		OfferReader offerReader = new OfferReader(packetFactory);
+
+		byte[] b = createOffer(true);
+		ByteArrayInputStream in = new ByteArrayInputStream(b);
+		Reader reader = readerFactory.createReader(in);
+		reader.addObjectReader(Types.OFFER, offerReader);
+
+		try {
+			reader.readStruct(Types.OFFER, Offer.class);
+			fail();
+		} catch(FormatException expected) {}
+		context.assertIsSatisfied();
+	}
+
+	@Test
+	@SuppressWarnings("unchecked")
+	public void testNoFormatExceptionIfOfferIsMaximumSize() throws Exception {
+		final PacketFactory packetFactory = context.mock(PacketFactory.class);
+		OfferReader offerReader = new OfferReader(packetFactory);
+		final Offer offer = context.mock(Offer.class);
+		context.checking(new Expectations() {{
+			oneOf(packetFactory).createOffer(with(any(Collection.class)));
+			will(returnValue(offer));
+		}});
+
+		byte[] b = createOffer(false);
+		ByteArrayInputStream in = new ByteArrayInputStream(b);
+		Reader reader = readerFactory.createReader(in);
+		reader.addObjectReader(Types.OFFER, offerReader);
+
+		assertEquals(offer, reader.readStruct(Types.OFFER, Offer.class));
+		context.assertIsSatisfied();
+	}
+
+	@Test
+	public void testEmptyOffer() throws Exception {
+		final PacketFactory packetFactory = context.mock(PacketFactory.class);
+		OfferReader offerReader = new OfferReader(packetFactory);
+
+		byte[] b = createEmptyOffer();
+		ByteArrayInputStream in = new ByteArrayInputStream(b);
+		Reader reader = readerFactory.createReader(in);
+		reader.addObjectReader(Types.OFFER, offerReader);
+
+		try {
+			reader.readStruct(Types.OFFER, Offer.class);
+			fail();
+		} catch(FormatException expected) {}
+		context.assertIsSatisfied();
+	}
+
+	private byte[] createOffer(boolean tooBig) throws Exception {
+		ByteArrayOutputStream out = new ByteArrayOutputStream();
+		Writer w = writerFactory.createWriter(out);
+		w.writeStructId(Types.OFFER);
+		w.writeListStart();
+		while(out.size() + serial.getSerialisedUniqueIdLength()
+				< ProtocolConstants.MAX_PACKET_LENGTH) {
+			w.writeBytes(TestUtils.getRandomId());
+		}
+		if(tooBig) w.writeBytes(TestUtils.getRandomId());
+		w.writeListEnd();
+		assertEquals(tooBig, out.size() > ProtocolConstants.MAX_PACKET_LENGTH);
+		return out.toByteArray();
+	}
+
+	private byte[] createEmptyOffer() throws Exception {
+		ByteArrayOutputStream out = new ByteArrayOutputStream();
+		Writer w = writerFactory.createWriter(out);
+		w.writeStructId(Types.OFFER);
+		w.writeListStart();
+		w.writeListEnd();
+		return out.toByteArray();
+	}
+}
-- 
GitLab