From b54984b542f7f8329ca839588439734e2462a330 Mon Sep 17 00:00:00 2001 From: akwizgran <akwizgran@users.sourceforge.net> Date: Wed, 21 Dec 2016 14:39:56 +0000 Subject: [PATCH] Unit tests for RecordReaderImpl. --- .../bramble/sync/RecordReaderImpl.java | 27 +-- .../bramble/sync/RecordReaderImplTest.java | 187 ++++++++++++------ 2 files changed, 136 insertions(+), 78 deletions(-) diff --git a/bramble-core/src/main/java/org/briarproject/bramble/sync/RecordReaderImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/sync/RecordReaderImpl.java index 5fdb18b1bf..e6a80735ab 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/sync/RecordReaderImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/sync/RecordReaderImpl.java @@ -33,7 +33,7 @@ import static org.briarproject.bramble.api.sync.SyncConstants.RECORD_HEADER_LENG @NotNullByDefault class RecordReaderImpl implements RecordReader { - private enum State { BUFFER_EMPTY, BUFFER_FULL, EOF } + private enum State {BUFFER_EMPTY, BUFFER_FULL, EOF} private final MessageFactory messageFactory; private final InputStream in; @@ -64,10 +64,11 @@ class RecordReaderImpl implements RecordReader { } offset += read; } - // Check the protocol version - if (header[0] != PROTOCOL_VERSION) throw new FormatException(); - // Read the payload length + byte version = header[0], type = header[1]; payloadLength = ByteUtils.readUint16(header, 2); + // Check the protocol version + if (version != PROTOCOL_VERSION) throw new FormatException(); + // Check the payload length if (payloadLength > MAX_RECORD_PAYLOAD_LENGTH) throw new FormatException(); // Read the payload @@ -78,21 +79,21 @@ class RecordReaderImpl implements RecordReader { offset += read; } state = State.BUFFER_FULL; - // Return if this is a known record type - if (header[1] == ACK || header[1] == MESSAGE || - header[1] == OFFER || header[1] == REQUEST) { + // Return if this is a known record type, otherwise continue + if (type == ACK || type == MESSAGE || type == OFFER || + type == REQUEST) { return; } } } /** - * The return value indicates whether there's another record available - * or whether we've reached the end of the input stream. - * If a record is available, - * it's been read into the buffer by the time eof() returns, - * so the method that called eof() can access the record from the buffer, - * for example to check its type or extract its payload. + * Returns true if there's another record available or false if we've + * reached the end of the input stream. + * <p> + * If a record is available, it's been read into the buffer by the time + * eof() returns, so the method that called eof() can access the record + * from the buffer, for example to check its type or extract its payload. */ @Override public boolean eof() throws IOException { diff --git a/bramble-core/src/test/java/org/briarproject/bramble/sync/RecordReaderImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/sync/RecordReaderImplTest.java index 8d1fca5154..7b927cf59b 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/sync/RecordReaderImplTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/sync/RecordReaderImplTest.java @@ -2,7 +2,9 @@ package org.briarproject.bramble.sync; import org.briarproject.bramble.api.FormatException; import org.briarproject.bramble.api.UniqueId; -import org.briarproject.bramble.test.BrambleTestCase; +import org.briarproject.bramble.api.sync.Ack; +import org.briarproject.bramble.api.sync.MessageFactory; +import org.briarproject.bramble.test.BrambleMockTestCase; import org.briarproject.bramble.test.TestUtils; import org.briarproject.bramble.util.ByteUtils; import org.junit.Test; @@ -13,17 +15,24 @@ import java.io.ByteArrayOutputStream; import static org.briarproject.bramble.api.sync.RecordTypes.ACK; import static org.briarproject.bramble.api.sync.RecordTypes.OFFER; import static org.briarproject.bramble.api.sync.RecordTypes.REQUEST; +import static org.briarproject.bramble.api.sync.SyncConstants.MAX_MESSAGE_IDS; import static org.briarproject.bramble.api.sync.SyncConstants.MAX_RECORD_PAYLOAD_LENGTH; +import static org.briarproject.bramble.api.sync.SyncConstants.PROTOCOL_VERSION; import static org.briarproject.bramble.api.sync.SyncConstants.RECORD_HEADER_LENGTH; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; -public class RecordReaderImplTest extends BrambleTestCase { +public class RecordReaderImplTest extends BrambleMockTestCase { + + private final MessageFactory messageFactory = + context.mock(MessageFactory.class); @Test(expected = FormatException.class) public void testFormatExceptionIfAckIsTooLarge() throws Exception { byte[] b = createAck(true); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readAck(); } @@ -31,15 +40,15 @@ public class RecordReaderImplTest extends BrambleTestCase { public void testNoFormatExceptionIfAckIsMaximumSize() throws Exception { byte[] b = createAck(false); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readAck(); } @Test(expected = FormatException.class) - public void testEmptyAck() throws Exception { + public void testFormatExceptionIfAckIsEmpty() throws Exception { byte[] b = createEmptyAck(); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readAck(); } @@ -47,7 +56,7 @@ public class RecordReaderImplTest extends BrambleTestCase { public void testFormatExceptionIfOfferIsTooLarge() throws Exception { byte[] b = createOffer(true); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readOffer(); } @@ -55,15 +64,15 @@ public class RecordReaderImplTest extends BrambleTestCase { public void testNoFormatExceptionIfOfferIsMaximumSize() throws Exception { byte[] b = createOffer(false); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readOffer(); } @Test(expected = FormatException.class) - public void testEmptyOffer() throws Exception { + public void testFormatExceptionIfOfferIsEmpty() throws Exception { byte[] b = createEmptyOffer(); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readOffer(); } @@ -71,7 +80,7 @@ public class RecordReaderImplTest extends BrambleTestCase { public void testFormatExceptionIfRequestIsTooLarge() throws Exception { byte[] b = createRequest(true); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readRequest(); } @@ -79,84 +88,132 @@ public class RecordReaderImplTest extends BrambleTestCase { public void testNoFormatExceptionIfRequestIsMaximumSize() throws Exception { byte[] b = createRequest(false); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readRequest(); } @Test(expected = FormatException.class) - public void testEmptyRequest() throws Exception { + public void testFormatExceptionIfRequestIsEmpty() throws Exception { byte[] b = createEmptyRequest(); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readRequest(); } + @Test + public void testEofReturnsTrueWhenAtEndOfStream() throws Exception { + ByteArrayInputStream in = new ByteArrayInputStream(new byte[0]); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); + assertTrue(reader.eof()); + } + + @Test + public void testEofReturnsFalseWhenNotAtEndOfStream() throws Exception { + byte[] b = createAck(false); + ByteArrayInputStream in = new ByteArrayInputStream(b); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); + assertFalse(reader.eof()); + } + + @Test(expected = FormatException.class) + public void testThrowsExceptionIfHeaderIsTooShort() throws Exception { + byte[] b = new byte[RECORD_HEADER_LENGTH - 1]; + b[0] = PROTOCOL_VERSION; + b[1] = ACK; + ByteArrayInputStream in = new ByteArrayInputStream(b); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); + reader.eof(); + } + + @Test(expected = FormatException.class) + public void testThrowsExceptionIfPayloadIsTooShort() throws Exception { + int payloadLength = 123; + byte[] b = new byte[RECORD_HEADER_LENGTH + payloadLength - 1]; + b[0] = PROTOCOL_VERSION; + b[1] = ACK; + ByteUtils.writeUint16(payloadLength, b, 2); + ByteArrayInputStream in = new ByteArrayInputStream(b); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); + reader.eof(); + } + + @Test(expected = FormatException.class) + public void testThrowsExceptionIfProtocolVersionIsUnrecognised() + throws Exception { + byte version = (byte) (PROTOCOL_VERSION + 1); + byte[] b = createRecord(version, ACK, new byte[0]); + ByteArrayInputStream in = new ByteArrayInputStream(b); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); + reader.eof(); + } + + @Test(expected = FormatException.class) + public void testThrowsExceptionIfPayloadIsTooLong() throws Exception { + byte[] payload = new byte[MAX_RECORD_PAYLOAD_LENGTH + 1]; + byte[] b = createRecord(PROTOCOL_VERSION, ACK, payload); + ByteArrayInputStream in = new ByteArrayInputStream(b); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); + reader.eof(); + } + + @Test + public void testSkipsUnrecognisedRecordTypes() throws Exception { + byte[] skip1 = createRecord(PROTOCOL_VERSION, (byte) (REQUEST + 1), + new byte[123]); + byte[] skip2 = createRecord(PROTOCOL_VERSION, (byte) (REQUEST + 2), + new byte[0]); + byte[] ack = createAck(false); + ByteArrayOutputStream input = new ByteArrayOutputStream(); + input.write(skip1); + input.write(skip2); + input.write(ack); + ByteArrayInputStream in = new ByteArrayInputStream(input.toByteArray()); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); + assertTrue(reader.hasAck()); + Ack a = reader.readAck(); + assertEquals(MAX_MESSAGE_IDS, a.getMessageIds().size()); + } + private byte[] createAck(boolean tooBig) throws Exception { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - out.write(new byte[RECORD_HEADER_LENGTH]); - while (out.size() + UniqueId.LENGTH <= RECORD_HEADER_LENGTH - + MAX_RECORD_PAYLOAD_LENGTH) { - out.write(TestUtils.getRandomId()); - } - if (tooBig) out.write(TestUtils.getRandomId()); - assertEquals(tooBig, out.size() > RECORD_HEADER_LENGTH + - MAX_RECORD_PAYLOAD_LENGTH); - byte[] record = out.toByteArray(); - record[1] = ACK; - ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2); - return record; + return createRecord(PROTOCOL_VERSION, ACK, createPayload(tooBig)); } private byte[] createEmptyAck() throws Exception { - byte[] record = new byte[RECORD_HEADER_LENGTH]; - record[1] = ACK; - ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2); - return record; + return createRecord(PROTOCOL_VERSION, ACK, new byte[0]); } private byte[] createOffer(boolean tooBig) throws Exception { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - out.write(new byte[RECORD_HEADER_LENGTH]); - while (out.size() + UniqueId.LENGTH <= RECORD_HEADER_LENGTH - + MAX_RECORD_PAYLOAD_LENGTH) { - out.write(TestUtils.getRandomId()); - } - if (tooBig) out.write(TestUtils.getRandomId()); - assertEquals(tooBig, out.size() > RECORD_HEADER_LENGTH + - MAX_RECORD_PAYLOAD_LENGTH); - byte[] record = out.toByteArray(); - record[1] = OFFER; - ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2); - return record; + return createRecord(PROTOCOL_VERSION, OFFER, createPayload(tooBig)); } private byte[] createEmptyOffer() throws Exception { - byte[] record = new byte[RECORD_HEADER_LENGTH]; - record[1] = OFFER; - ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2); - return record; + return createRecord(PROTOCOL_VERSION, OFFER, new byte[0]); } private byte[] createRequest(boolean tooBig) throws Exception { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - out.write(new byte[RECORD_HEADER_LENGTH]); - while (out.size() + UniqueId.LENGTH <= RECORD_HEADER_LENGTH - + MAX_RECORD_PAYLOAD_LENGTH) { - out.write(TestUtils.getRandomId()); - } - if (tooBig) out.write(TestUtils.getRandomId()); - assertEquals(tooBig, out.size() > RECORD_HEADER_LENGTH + - MAX_RECORD_PAYLOAD_LENGTH); - byte[] record = out.toByteArray(); - record[1] = REQUEST; - ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2); - return record; + return createRecord(PROTOCOL_VERSION, REQUEST, createPayload(tooBig)); } private byte[] createEmptyRequest() throws Exception { - byte[] record = new byte[RECORD_HEADER_LENGTH]; - record[1] = REQUEST; - ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2); - return record; + return createRecord(PROTOCOL_VERSION, REQUEST, new byte[0]); + } + + private byte[] createRecord(byte version, byte type, byte[] payload) { + byte[] b = new byte[RECORD_HEADER_LENGTH + payload.length]; + b[0] = version; + b[1] = type; + ByteUtils.writeUint16(payload.length, b, 2); + System.arraycopy(payload, 0, b, RECORD_HEADER_LENGTH, payload.length); + return b; + } + + private byte[] createPayload(boolean tooBig) throws Exception { + ByteArrayOutputStream payload = new ByteArrayOutputStream(); + while (payload.size() + UniqueId.LENGTH <= MAX_RECORD_PAYLOAD_LENGTH) { + payload.write(TestUtils.getRandomId()); + } + if (tooBig) payload.write(TestUtils.getRandomId()); + assertEquals(tooBig, payload.size() > MAX_RECORD_PAYLOAD_LENGTH); + return payload.toByteArray(); } } -- GitLab