From 1c41ffa7afd559364f5c96806c5c5da32c7da018 Mon Sep 17 00:00:00 2001 From: akwizgran <akwizgran@users.sourceforge.net> Date: Wed, 7 Dec 2011 13:32:17 +0000 Subject: [PATCH] Don't accept empty acks, batches or offers. --- .../net/sf/briar/protocol/AckReader.java | 4 +- .../net/sf/briar/protocol/BatchReader.java | 2 + .../net/sf/briar/protocol/OfferReader.java | 4 +- test/build.xml | 1 + test/net/sf/briar/protocol/AckReaderTest.java | 31 ++--- .../sf/briar/protocol/BatchReaderTest.java | 12 +- .../sf/briar/protocol/OfferReaderTest.java | 121 ++++++++++++++++++ 7 files changed, 143 insertions(+), 32 deletions(-) create mode 100644 test/net/sf/briar/protocol/OfferReaderTest.java diff --git a/components/net/sf/briar/protocol/AckReader.java b/components/net/sf/briar/protocol/AckReader.java index f2c4a899f4..977eaa8dab 100644 --- a/components/net/sf/briar/protocol/AckReader.java +++ b/components/net/sf/briar/protocol/AckReader.java @@ -2,7 +2,6 @@ package net.sf.briar.protocol; import java.io.IOException; import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; import java.util.List; @@ -35,9 +34,10 @@ class AckReader implements ObjectReader<Ack> { r.addConsumer(counting); r.readStructId(Types.ACK); r.setMaxBytesLength(UniqueId.LENGTH); - Collection<Bytes> raw = r.readList(Bytes.class); + List<Bytes> raw = r.readList(Bytes.class); r.resetMaxBytesLength(); r.removeConsumer(counting); + if(raw.isEmpty()) throw new FormatException(); // Convert the byte arrays to batch IDs List<BatchId> batches = new ArrayList<BatchId>(); for(Bytes b : raw) { diff --git a/components/net/sf/briar/protocol/BatchReader.java b/components/net/sf/briar/protocol/BatchReader.java index 1f67b3d1b0..de7f5e31e4 100644 --- a/components/net/sf/briar/protocol/BatchReader.java +++ b/components/net/sf/briar/protocol/BatchReader.java @@ -3,6 +3,7 @@ package net.sf.briar.protocol; import java.io.IOException; import java.util.List; +import net.sf.briar.api.FormatException; import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.MessageDigest; import net.sf.briar.api.protocol.BatchId; @@ -44,6 +45,7 @@ class BatchReader implements ObjectReader<UnverifiedBatch> { r.removeObjectReader(Types.MESSAGE); r.removeConsumer(digesting); r.removeConsumer(counting); + if(messages.isEmpty()) throw new FormatException(); // Build and return the batch BatchId id = new BatchId(messageDigest.digest()); return batchFactory.createUnverifiedBatch(id, messages); diff --git a/components/net/sf/briar/protocol/OfferReader.java b/components/net/sf/briar/protocol/OfferReader.java index f87247dda4..d7e7dcd878 100644 --- a/components/net/sf/briar/protocol/OfferReader.java +++ b/components/net/sf/briar/protocol/OfferReader.java @@ -2,7 +2,6 @@ package net.sf.briar.protocol; import java.io.IOException; import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; import java.util.List; @@ -35,9 +34,10 @@ class OfferReader implements ObjectReader<Offer> { r.addConsumer(counting); r.readStructId(Types.OFFER); r.setMaxBytesLength(UniqueId.LENGTH); - Collection<Bytes> raw = r.readList(Bytes.class); + List<Bytes> raw = r.readList(Bytes.class); r.resetMaxBytesLength(); r.removeConsumer(counting); + if(raw.isEmpty()) throw new FormatException(); // Convert the byte arrays to message IDs List<MessageId> messages = new ArrayList<MessageId>(); for(Bytes b : raw) { diff --git a/test/build.xml b/test/build.xml index f40c9ce045..d865f5bd5a 100644 --- a/test/build.xml +++ b/test/build.xml @@ -39,6 +39,7 @@ <test name='net.sf.briar.protocol.BatchReaderTest'/> <test name='net.sf.briar.protocol.ConstantsTest'/> <test name='net.sf.briar.protocol.ConsumersTest'/> + <test name='net.sf.briar.protocol.OfferReaderTest'/> <test name='net.sf.briar.protocol.ProtocolReadWriteTest'/> <test name='net.sf.briar.protocol.ProtocolWriterImplTest'/> <test name='net.sf.briar.protocol.RequestReaderTest'/> diff --git a/test/net/sf/briar/protocol/AckReaderTest.java b/test/net/sf/briar/protocol/AckReaderTest.java index 3ff8ab8e0d..878718082f 100644 --- a/test/net/sf/briar/protocol/AckReaderTest.java +++ b/test/net/sf/briar/protocol/AckReaderTest.java @@ -3,19 +3,17 @@ package net.sf.briar.protocol; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.util.Collection; -import java.util.Collections; -import java.util.Random; import junit.framework.TestCase; +import net.sf.briar.TestUtils; import net.sf.briar.api.FormatException; import net.sf.briar.api.protocol.Ack; -import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Types; -import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.ReaderFactory; +import net.sf.briar.api.serial.SerialComponent; import net.sf.briar.api.serial.Writer; import net.sf.briar.api.serial.WriterFactory; import net.sf.briar.serial.SerialModule; @@ -29,6 +27,7 @@ import com.google.inject.Injector; public class AckReaderTest extends TestCase { + private final SerialComponent serial; private final ReaderFactory readerFactory; private final WriterFactory writerFactory; private final Mockery context; @@ -36,6 +35,7 @@ public class AckReaderTest extends TestCase { public AckReaderTest() throws Exception { super(); Injector i = Guice.createInjector(new SerialModule()); + serial = i.getInstance(SerialComponent.class); readerFactory = i.getInstance(ReaderFactory.class); writerFactory = i.getInstance(WriterFactory.class); context = new Mockery(); @@ -82,19 +82,16 @@ public class AckReaderTest extends TestCase { public void testEmptyAck() throws Exception { final PacketFactory packetFactory = context.mock(PacketFactory.class); AckReader ackReader = new AckReader(packetFactory); - final Ack ack = context.mock(Ack.class); - context.checking(new Expectations() {{ - oneOf(packetFactory).createAck( - with(Collections.<BatchId>emptyList())); - will(returnValue(ack)); - }}); byte[] b = createEmptyAck(); ByteArrayInputStream in = new ByteArrayInputStream(b); Reader reader = readerFactory.createReader(in); reader.addObjectReader(Types.ACK, ackReader); - assertEquals(ack, reader.readStruct(Types.ACK, Ack.class)); + try { + reader.readStruct(Types.ACK, Ack.class); + fail(); + } catch(FormatException expected) {} context.assertIsSatisfied(); } @@ -103,17 +100,11 @@ public class AckReaderTest extends TestCase { Writer w = writerFactory.createWriter(out); w.writeStructId(Types.ACK); w.writeListStart(); - byte[] b = new byte[UniqueId.LENGTH]; - Random random = new Random(); - while(out.size() + BatchId.LENGTH + 3 + while(out.size() + serial.getSerialisedUniqueIdLength() < ProtocolConstants.MAX_PACKET_LENGTH) { - random.nextBytes(b); - w.writeBytes(b); - } - if(tooBig) { - random.nextBytes(b); - w.writeBytes(b); + w.writeBytes(TestUtils.getRandomId()); } + if(tooBig) w.writeBytes(TestUtils.getRandomId()); w.writeListEnd(); assertEquals(tooBig, out.size() > ProtocolConstants.MAX_PACKET_LENGTH); return out.toByteArray(); diff --git a/test/net/sf/briar/protocol/BatchReaderTest.java b/test/net/sf/briar/protocol/BatchReaderTest.java index 8f8898fd5d..6ff7bfecee 100644 --- a/test/net/sf/briar/protocol/BatchReaderTest.java +++ b/test/net/sf/briar/protocol/BatchReaderTest.java @@ -127,20 +127,16 @@ public class BatchReaderTest extends TestCase { context.mock(UnverifiedBatchFactory.class); BatchReader batchReader = new BatchReader(crypto, messageReader, batchFactory); - final UnverifiedBatch batch = context.mock(UnverifiedBatch.class); - context.checking(new Expectations() {{ - oneOf(batchFactory).createUnverifiedBatch(with(any(BatchId.class)), - with(Collections.<UnverifiedMessage>emptyList())); - will(returnValue(batch)); - }}); byte[] b = createEmptyBatch(); ByteArrayInputStream in = new ByteArrayInputStream(b); Reader reader = readerFactory.createReader(in); reader.addObjectReader(Types.BATCH, batchReader); - assertEquals(batch, reader.readStruct(Types.BATCH, - UnverifiedBatch.class)); + try { + reader.readStruct(Types.BATCH, UnverifiedBatch.class); + fail(); + } catch(FormatException expected) {} context.assertIsSatisfied(); } diff --git a/test/net/sf/briar/protocol/OfferReaderTest.java b/test/net/sf/briar/protocol/OfferReaderTest.java new file mode 100644 index 0000000000..6c25e129d3 --- /dev/null +++ b/test/net/sf/briar/protocol/OfferReaderTest.java @@ -0,0 +1,121 @@ +package net.sf.briar.protocol; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.Collection; + +import junit.framework.TestCase; +import net.sf.briar.TestUtils; +import net.sf.briar.api.FormatException; +import net.sf.briar.api.protocol.Offer; +import net.sf.briar.api.protocol.PacketFactory; +import net.sf.briar.api.protocol.ProtocolConstants; +import net.sf.briar.api.protocol.Types; +import net.sf.briar.api.serial.Reader; +import net.sf.briar.api.serial.ReaderFactory; +import net.sf.briar.api.serial.SerialComponent; +import net.sf.briar.api.serial.Writer; +import net.sf.briar.api.serial.WriterFactory; +import net.sf.briar.serial.SerialModule; + +import org.jmock.Expectations; +import org.jmock.Mockery; +import org.junit.Test; + +import com.google.inject.Guice; +import com.google.inject.Injector; + +public class OfferReaderTest extends TestCase { + + private final SerialComponent serial; + private final ReaderFactory readerFactory; + private final WriterFactory writerFactory; + private final Mockery context; + + public OfferReaderTest() throws Exception { + super(); + Injector i = Guice.createInjector(new SerialModule()); + serial = i.getInstance(SerialComponent.class); + readerFactory = i.getInstance(ReaderFactory.class); + writerFactory = i.getInstance(WriterFactory.class); + context = new Mockery(); + } + + @Test + public void testFormatExceptionIfOfferIsTooLarge() throws Exception { + PacketFactory packetFactory = context.mock(PacketFactory.class); + OfferReader offerReader = new OfferReader(packetFactory); + + byte[] b = createOffer(true); + ByteArrayInputStream in = new ByteArrayInputStream(b); + Reader reader = readerFactory.createReader(in); + reader.addObjectReader(Types.OFFER, offerReader); + + try { + reader.readStruct(Types.OFFER, Offer.class); + fail(); + } catch(FormatException expected) {} + context.assertIsSatisfied(); + } + + @Test + @SuppressWarnings("unchecked") + public void testNoFormatExceptionIfOfferIsMaximumSize() throws Exception { + final PacketFactory packetFactory = context.mock(PacketFactory.class); + OfferReader offerReader = new OfferReader(packetFactory); + final Offer offer = context.mock(Offer.class); + context.checking(new Expectations() {{ + oneOf(packetFactory).createOffer(with(any(Collection.class))); + will(returnValue(offer)); + }}); + + byte[] b = createOffer(false); + ByteArrayInputStream in = new ByteArrayInputStream(b); + Reader reader = readerFactory.createReader(in); + reader.addObjectReader(Types.OFFER, offerReader); + + assertEquals(offer, reader.readStruct(Types.OFFER, Offer.class)); + context.assertIsSatisfied(); + } + + @Test + public void testEmptyOffer() throws Exception { + final PacketFactory packetFactory = context.mock(PacketFactory.class); + OfferReader offerReader = new OfferReader(packetFactory); + + byte[] b = createEmptyOffer(); + ByteArrayInputStream in = new ByteArrayInputStream(b); + Reader reader = readerFactory.createReader(in); + reader.addObjectReader(Types.OFFER, offerReader); + + try { + reader.readStruct(Types.OFFER, Offer.class); + fail(); + } catch(FormatException expected) {} + context.assertIsSatisfied(); + } + + private byte[] createOffer(boolean tooBig) throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + Writer w = writerFactory.createWriter(out); + w.writeStructId(Types.OFFER); + w.writeListStart(); + while(out.size() + serial.getSerialisedUniqueIdLength() + < ProtocolConstants.MAX_PACKET_LENGTH) { + w.writeBytes(TestUtils.getRandomId()); + } + if(tooBig) w.writeBytes(TestUtils.getRandomId()); + w.writeListEnd(); + assertEquals(tooBig, out.size() > ProtocolConstants.MAX_PACKET_LENGTH); + return out.toByteArray(); + } + + private byte[] createEmptyOffer() throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + Writer w = writerFactory.createWriter(out); + w.writeStructId(Types.OFFER); + w.writeListStart(); + w.writeListEnd(); + return out.toByteArray(); + } +} -- GitLab