Unit tests for RecordReaderImpl.

parent 2390f767
......@@ -33,7 +33,7 @@ import static org.briarproject.bramble.api.sync.SyncConstants.RECORD_HEADER_LENG
@NotNullByDefault
class RecordReaderImpl implements RecordReader {
private enum State { BUFFER_EMPTY, BUFFER_FULL, EOF }
private enum State {BUFFER_EMPTY, BUFFER_FULL, EOF}
private final MessageFactory messageFactory;
private final InputStream in;
......@@ -64,10 +64,11 @@ class RecordReaderImpl implements RecordReader {
}
offset += read;
}
// Check the protocol version
if (header[0] != PROTOCOL_VERSION) throw new FormatException();
// Read the payload length
byte version = header[0], type = header[1];
payloadLength = ByteUtils.readUint16(header, 2);
// Check the protocol version
if (version != PROTOCOL_VERSION) throw new FormatException();
// Check the payload length
if (payloadLength > MAX_RECORD_PAYLOAD_LENGTH)
throw new FormatException();
// Read the payload
......@@ -78,21 +79,21 @@ class RecordReaderImpl implements RecordReader {
offset += read;
}
state = State.BUFFER_FULL;
// Return if this is a known record type
if (header[1] == ACK || header[1] == MESSAGE ||
header[1] == OFFER || header[1] == REQUEST) {
// Return if this is a known record type, otherwise continue
if (type == ACK || type == MESSAGE || type == OFFER ||
type == REQUEST) {
return;
}
}
}
/**
* The return value indicates whether there's another record available
* or whether we've reached the end of the input stream.
* If a record is available,
* it's been read into the buffer by the time eof() returns,
* so the method that called eof() can access the record from the buffer,
* for example to check its type or extract its payload.
* Returns true if there's another record available or false if we've
* reached the end of the input stream.
* <p>
* If a record is available, it's been read into the buffer by the time
* eof() returns, so the method that called eof() can access the record
* from the buffer, for example to check its type or extract its payload.
*/
@Override
public boolean eof() throws IOException {
......
......@@ -2,7 +2,9 @@ package org.briarproject.bramble.sync;
import org.briarproject.bramble.api.FormatException;
import org.briarproject.bramble.api.UniqueId;
import org.briarproject.bramble.test.BrambleTestCase;
import org.briarproject.bramble.api.sync.Ack;
import org.briarproject.bramble.api.sync.MessageFactory;
import org.briarproject.bramble.test.BrambleMockTestCase;
import org.briarproject.bramble.test.TestUtils;
import org.briarproject.bramble.util.ByteUtils;
import org.junit.Test;
......@@ -13,17 +15,24 @@ import java.io.ByteArrayOutputStream;
import static org.briarproject.bramble.api.sync.RecordTypes.ACK;
import static org.briarproject.bramble.api.sync.RecordTypes.OFFER;
import static org.briarproject.bramble.api.sync.RecordTypes.REQUEST;
import static org.briarproject.bramble.api.sync.SyncConstants.MAX_MESSAGE_IDS;
import static org.briarproject.bramble.api.sync.SyncConstants.MAX_RECORD_PAYLOAD_LENGTH;
import static org.briarproject.bramble.api.sync.SyncConstants.PROTOCOL_VERSION;
import static org.briarproject.bramble.api.sync.SyncConstants.RECORD_HEADER_LENGTH;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
public class RecordReaderImplTest extends BrambleTestCase {
public class RecordReaderImplTest extends BrambleMockTestCase {
private final MessageFactory messageFactory =
context.mock(MessageFactory.class);
@Test(expected = FormatException.class)
public void testFormatExceptionIfAckIsTooLarge() throws Exception {
byte[] b = createAck(true);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(null, in);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readAck();
}
......@@ -31,15 +40,15 @@ public class RecordReaderImplTest extends BrambleTestCase {
public void testNoFormatExceptionIfAckIsMaximumSize() throws Exception {
byte[] b = createAck(false);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(null, in);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readAck();
}
@Test(expected = FormatException.class)
public void testEmptyAck() throws Exception {
public void testFormatExceptionIfAckIsEmpty() throws Exception {
byte[] b = createEmptyAck();
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(null, in);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readAck();
}
......@@ -47,7 +56,7 @@ public class RecordReaderImplTest extends BrambleTestCase {
public void testFormatExceptionIfOfferIsTooLarge() throws Exception {
byte[] b = createOffer(true);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(null, in);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readOffer();
}
......@@ -55,15 +64,15 @@ public class RecordReaderImplTest extends BrambleTestCase {
public void testNoFormatExceptionIfOfferIsMaximumSize() throws Exception {
byte[] b = createOffer(false);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(null, in);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readOffer();
}
@Test(expected = FormatException.class)
public void testEmptyOffer() throws Exception {
public void testFormatExceptionIfOfferIsEmpty() throws Exception {
byte[] b = createEmptyOffer();
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(null, in);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readOffer();
}
......@@ -71,7 +80,7 @@ public class RecordReaderImplTest extends BrambleTestCase {
public void testFormatExceptionIfRequestIsTooLarge() throws Exception {
byte[] b = createRequest(true);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(null, in);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readRequest();
}
......@@ -79,84 +88,132 @@ public class RecordReaderImplTest extends BrambleTestCase {
public void testNoFormatExceptionIfRequestIsMaximumSize() throws Exception {
byte[] b = createRequest(false);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(null, in);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readRequest();
}
@Test(expected = FormatException.class)
public void testEmptyRequest() throws Exception {
public void testFormatExceptionIfRequestIsEmpty() throws Exception {
byte[] b = createEmptyRequest();
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(null, in);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.readRequest();
}
@Test
public void testEofReturnsTrueWhenAtEndOfStream() throws Exception {
ByteArrayInputStream in = new ByteArrayInputStream(new byte[0]);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
assertTrue(reader.eof());
}
@Test
public void testEofReturnsFalseWhenNotAtEndOfStream() throws Exception {
byte[] b = createAck(false);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
assertFalse(reader.eof());
}
@Test(expected = FormatException.class)
public void testThrowsExceptionIfHeaderIsTooShort() throws Exception {
byte[] b = new byte[RECORD_HEADER_LENGTH - 1];
b[0] = PROTOCOL_VERSION;
b[1] = ACK;
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.eof();
}
@Test(expected = FormatException.class)
public void testThrowsExceptionIfPayloadIsTooShort() throws Exception {
int payloadLength = 123;
byte[] b = new byte[RECORD_HEADER_LENGTH + payloadLength - 1];
b[0] = PROTOCOL_VERSION;
b[1] = ACK;
ByteUtils.writeUint16(payloadLength, b, 2);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.eof();
}
@Test(expected = FormatException.class)
public void testThrowsExceptionIfProtocolVersionIsUnrecognised()
throws Exception {
byte version = (byte) (PROTOCOL_VERSION + 1);
byte[] b = createRecord(version, ACK, new byte[0]);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.eof();
}
@Test(expected = FormatException.class)
public void testThrowsExceptionIfPayloadIsTooLong() throws Exception {
byte[] payload = new byte[MAX_RECORD_PAYLOAD_LENGTH + 1];
byte[] b = createRecord(PROTOCOL_VERSION, ACK, payload);
ByteArrayInputStream in = new ByteArrayInputStream(b);
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
reader.eof();
}
@Test
public void testSkipsUnrecognisedRecordTypes() throws Exception {
byte[] skip1 = createRecord(PROTOCOL_VERSION, (byte) (REQUEST + 1),
new byte[123]);
byte[] skip2 = createRecord(PROTOCOL_VERSION, (byte) (REQUEST + 2),
new byte[0]);
byte[] ack = createAck(false);
ByteArrayOutputStream input = new ByteArrayOutputStream();
input.write(skip1);
input.write(skip2);
input.write(ack);
ByteArrayInputStream in = new ByteArrayInputStream(input.toByteArray());
RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in);
assertTrue(reader.hasAck());
Ack a = reader.readAck();
assertEquals(MAX_MESSAGE_IDS, a.getMessageIds().size());
}
private byte[] createAck(boolean tooBig) throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
out.write(new byte[RECORD_HEADER_LENGTH]);
while (out.size() + UniqueId.LENGTH <= RECORD_HEADER_LENGTH
+ MAX_RECORD_PAYLOAD_LENGTH) {
out.write(TestUtils.getRandomId());
}
if (tooBig) out.write(TestUtils.getRandomId());
assertEquals(tooBig, out.size() > RECORD_HEADER_LENGTH +
MAX_RECORD_PAYLOAD_LENGTH);
byte[] record = out.toByteArray();
record[1] = ACK;
ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2);
return record;
return createRecord(PROTOCOL_VERSION, ACK, createPayload(tooBig));
}
private byte[] createEmptyAck() throws Exception {
byte[] record = new byte[RECORD_HEADER_LENGTH];
record[1] = ACK;
ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2);
return record;
return createRecord(PROTOCOL_VERSION, ACK, new byte[0]);
}
private byte[] createOffer(boolean tooBig) throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
out.write(new byte[RECORD_HEADER_LENGTH]);
while (out.size() + UniqueId.LENGTH <= RECORD_HEADER_LENGTH
+ MAX_RECORD_PAYLOAD_LENGTH) {
out.write(TestUtils.getRandomId());
}
if (tooBig) out.write(TestUtils.getRandomId());
assertEquals(tooBig, out.size() > RECORD_HEADER_LENGTH +
MAX_RECORD_PAYLOAD_LENGTH);
byte[] record = out.toByteArray();
record[1] = OFFER;
ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2);
return record;
return createRecord(PROTOCOL_VERSION, OFFER, createPayload(tooBig));
}
private byte[] createEmptyOffer() throws Exception {
byte[] record = new byte[RECORD_HEADER_LENGTH];
record[1] = OFFER;
ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2);
return record;
return createRecord(PROTOCOL_VERSION, OFFER, new byte[0]);
}
private byte[] createRequest(boolean tooBig) throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
out.write(new byte[RECORD_HEADER_LENGTH]);
while (out.size() + UniqueId.LENGTH <= RECORD_HEADER_LENGTH
+ MAX_RECORD_PAYLOAD_LENGTH) {
out.write(TestUtils.getRandomId());
}
if (tooBig) out.write(TestUtils.getRandomId());
assertEquals(tooBig, out.size() > RECORD_HEADER_LENGTH +
MAX_RECORD_PAYLOAD_LENGTH);
byte[] record = out.toByteArray();
record[1] = REQUEST;
ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2);
return record;
return createRecord(PROTOCOL_VERSION, REQUEST, createPayload(tooBig));
}
private byte[] createEmptyRequest() throws Exception {
byte[] record = new byte[RECORD_HEADER_LENGTH];
record[1] = REQUEST;
ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2);
return record;
return createRecord(PROTOCOL_VERSION, REQUEST, new byte[0]);
}
private byte[] createRecord(byte version, byte type, byte[] payload) {
byte[] b = new byte[RECORD_HEADER_LENGTH + payload.length];
b[0] = version;
b[1] = type;
ByteUtils.writeUint16(payload.length, b, 2);
System.arraycopy(payload, 0, b, RECORD_HEADER_LENGTH, payload.length);
return b;
}
private byte[] createPayload(boolean tooBig) throws Exception {
ByteArrayOutputStream payload = new ByteArrayOutputStream();
while (payload.size() + UniqueId.LENGTH <= MAX_RECORD_PAYLOAD_LENGTH) {
payload.write(TestUtils.getRandomId());
}
if (tooBig) payload.write(TestUtils.getRandomId());
assertEquals(tooBig, payload.size() > MAX_RECORD_PAYLOAD_LENGTH);
return payload.toByteArray();
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment