diff --git a/bramble-core/src/test/java/org/briarproject/bramble/record/RecordReaderImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/record/RecordReaderImplTest.java new file mode 100644 index 0000000000000000000000000000000000000000..26ef89c8ee8e3a190606713c0b8d38cb26359048 --- /dev/null +++ b/bramble-core/src/test/java/org/briarproject/bramble/record/RecordReaderImplTest.java @@ -0,0 +1,102 @@ +package org.briarproject.bramble.record; + +import org.briarproject.bramble.api.FormatException; +import org.briarproject.bramble.api.record.Record; +import org.briarproject.bramble.api.record.RecordReader; +import org.briarproject.bramble.test.BrambleTestCase; +import org.briarproject.bramble.util.ByteUtils; +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.EOFException; + +import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES; +import static org.briarproject.bramble.api.record.Record.RECORD_HEADER_BYTES; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +public class RecordReaderImplTest extends BrambleTestCase { + + @Test + public void testAcceptsEmptyPayload() throws Exception { + // Version 1, type 2, payload length 0 + byte[] header = new byte[] {1, 2, 0, 0}; + ByteArrayInputStream in = new ByteArrayInputStream(header); + RecordReader reader = new RecordReaderImpl(in); + Record record = reader.readRecord(); + assertEquals(1, record.getProtocolVersion()); + assertEquals(2, record.getRecordType()); + assertArrayEquals(new byte[0], record.getPayload()); + } + + @Test + public void testAcceptsMaxLengthPayload() throws Exception { + byte[] record = + new byte[RECORD_HEADER_BYTES + MAX_RECORD_PAYLOAD_BYTES]; + // Version 1, type 2, payload length MAX_RECORD_PAYLOAD_BYTES + record[0] = 1; + record[1] = 2; + ByteUtils.writeUint16(MAX_RECORD_PAYLOAD_BYTES, record, 2); + ByteArrayInputStream in = new ByteArrayInputStream(record); + RecordReader reader = new RecordReaderImpl(in); + reader.readRecord(); + } + + @Test(expected = FormatException.class) + public void testFormatExceptionIfPayloadLengthIsNegative() + throws Exception { + // Version 1, type 2, payload length -1 + byte[] header = new byte[] {1, 2, (byte) 0xFF, (byte) 0xFF}; + ByteArrayInputStream in = new ByteArrayInputStream(header); + RecordReader reader = new RecordReaderImpl(in); + reader.readRecord(); + } + + @Test(expected = FormatException.class) + public void testFormatExceptionIfPayloadLengthIsTooLarge() + throws Exception { + // Version 1, type 2, payload length MAX_RECORD_PAYLOAD_BYTES + 1 + byte[] header = new byte[] {1, 2, 0, 0}; + ByteUtils.writeUint16(MAX_RECORD_PAYLOAD_BYTES + 1, header, 2); + ByteArrayInputStream in = new ByteArrayInputStream(header); + RecordReader reader = new RecordReaderImpl(in); + reader.readRecord(); + } + + @Test(expected = EOFException.class) + public void testEofExceptionIfProtocolVersionIsMissing() throws Exception { + ByteArrayInputStream in = new ByteArrayInputStream(new byte[0]); + RecordReader reader = new RecordReaderImpl(in); + reader.readRecord(); + } + + @Test(expected = EOFException.class) + public void testEofExceptionIfRecordTypeIsMissing() throws Exception { + ByteArrayInputStream in = new ByteArrayInputStream(new byte[1]); + RecordReader reader = new RecordReaderImpl(in); + reader.readRecord(); + } + + @Test(expected = EOFException.class) + public void testEofExceptionIfPayloadLengthIsMissing() throws Exception { + ByteArrayInputStream in = new ByteArrayInputStream(new byte[2]); + RecordReader reader = new RecordReaderImpl(in); + reader.readRecord(); + } + + @Test(expected = EOFException.class) + public void testEofExceptionIfPayloadLengthIsTruncated() throws Exception { + ByteArrayInputStream in = new ByteArrayInputStream(new byte[3]); + RecordReader reader = new RecordReaderImpl(in); + reader.readRecord(); + } + + @Test(expected = EOFException.class) + public void testEofExceptionIfPayloadIsTruncated() throws Exception { + // Version 0, type 0, payload length 1 + byte[] header = new byte[] {0, 0, 0, 1}; + ByteArrayInputStream in = new ByteArrayInputStream(header); + RecordReader reader = new RecordReaderImpl(in); + reader.readRecord(); + } +} diff --git a/bramble-core/src/test/java/org/briarproject/bramble/record/RecordWriterImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/record/RecordWriterImplTest.java new file mode 100644 index 0000000000000000000000000000000000000000..2e5f236befaaeaadda8b810a6fc00335ee1df480 --- /dev/null +++ b/bramble-core/src/test/java/org/briarproject/bramble/record/RecordWriterImplTest.java @@ -0,0 +1,49 @@ +package org.briarproject.bramble.record; + +import org.briarproject.bramble.api.record.Record; +import org.briarproject.bramble.api.record.RecordWriter; +import org.briarproject.bramble.test.BrambleTestCase; +import org.briarproject.bramble.util.ByteUtils; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; + +import static org.briarproject.bramble.api.record.Record.MAX_RECORD_PAYLOAD_BYTES; +import static org.briarproject.bramble.api.record.Record.RECORD_HEADER_BYTES; +import static org.briarproject.bramble.test.TestUtils.getRandomBytes; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +public class RecordWriterImplTest extends BrambleTestCase { + + @Test + public void testWritesEmptyRecord() throws Exception { + testWritesRecord(0); + } + + @Test + public void testWritesMaxLengthRecord() throws Exception { + testWritesRecord(MAX_RECORD_PAYLOAD_BYTES); + } + + private void testWritesRecord(int payloadLength) throws Exception { + byte protocolVersion = 123; + byte recordType = 45; + byte[] payload = getRandomBytes(payloadLength); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + RecordWriter writer = new RecordWriterImpl(out); + writer.writeRecord(new Record(protocolVersion, recordType, payload)); + writer.flush(); + byte[] written = out.toByteArray(); + + assertEquals(RECORD_HEADER_BYTES + payloadLength, written.length); + assertEquals(protocolVersion, written[0]); + assertEquals(recordType, written[1]); + assertEquals(payloadLength, ByteUtils.readUint16(written, 2)); + byte[] writtenPayload = new byte[payloadLength]; + System.arraycopy(written, RECORD_HEADER_BYTES, writtenPayload, 0, + payloadLength); + assertArrayEquals(payload, writtenPayload); + } +}