diff --git a/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java b/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java index d735145ac26f49a5d49e2169610f9124823e6de7..5a58743ff57a029c8d440e0b951323da01ba1796 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java @@ -98,28 +98,24 @@ class KeyAgreementTransport { private byte[] readRecord(byte expectedType) throws AbortException { while (true) { byte[] header = readHeader(); + byte version = header[0], type = header[1]; int len = ByteUtils.readUint16(header, RECORD_HEADER_PAYLOAD_LENGTH_OFFSET); - if (header[0] != PROTOCOL_VERSION) { - throw new AbortException(false); - } - byte type = header[1]; + // Reject unrecognised protocol version + if (version != PROTOCOL_VERSION) throw new AbortException(false); if (type == ABORT) throw new AbortException(true); - if (type != expectedType) { - if (type != KEY && type != CONFIRM) { - // ignore unrecognised record and try next - try { - readData(len); - } catch (IOException e) { - throw new AbortException(e); - } - continue; - } else { - throw new AbortException(false); + if (type == expectedType) { + try { + return readData(len); + } catch (IOException e) { + throw new AbortException(e); } } + // Reject recognised but unexpected record type + if (type == KEY || type == CONFIRM) throw new AbortException(false); + // Skip unrecognised record type try { - return readData(len); + readData(len); } catch (IOException e) { throw new AbortException(e); } diff --git a/bramble-core/src/main/java/org/briarproject/bramble/sync/RecordReaderImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/sync/RecordReaderImpl.java index 5fdb18b1bfa6f12e8493b8a0a3f03427e1563c2b..e6a80735ab450242e16b754bea3da49681c3d6b2 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/sync/RecordReaderImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/sync/RecordReaderImpl.java @@ -33,7 +33,7 @@ import static org.briarproject.bramble.api.sync.SyncConstants.RECORD_HEADER_LENG @NotNullByDefault class RecordReaderImpl implements RecordReader { - private enum State { BUFFER_EMPTY, BUFFER_FULL, EOF } + private enum State {BUFFER_EMPTY, BUFFER_FULL, EOF} private final MessageFactory messageFactory; private final InputStream in; @@ -64,10 +64,11 @@ class RecordReaderImpl implements RecordReader { } offset += read; } - // Check the protocol version - if (header[0] != PROTOCOL_VERSION) throw new FormatException(); - // Read the payload length + byte version = header[0], type = header[1]; payloadLength = ByteUtils.readUint16(header, 2); + // Check the protocol version + if (version != PROTOCOL_VERSION) throw new FormatException(); + // Check the payload length if (payloadLength > MAX_RECORD_PAYLOAD_LENGTH) throw new FormatException(); // Read the payload @@ -78,21 +79,21 @@ class RecordReaderImpl implements RecordReader { offset += read; } state = State.BUFFER_FULL; - // Return if this is a known record type - if (header[1] == ACK || header[1] == MESSAGE || - header[1] == OFFER || header[1] == REQUEST) { + // Return if this is a known record type, otherwise continue + if (type == ACK || type == MESSAGE || type == OFFER || + type == REQUEST) { return; } } } /** - * The return value indicates whether there's another record available - * or whether we've reached the end of the input stream. - * If a record is available, - * it's been read into the buffer by the time eof() returns, - * so the method that called eof() can access the record from the buffer, - * for example to check its type or extract its payload. + * Returns true if there's another record available or false if we've + * reached the end of the input stream. + * <p> + * If a record is available, it's been read into the buffer by the time + * eof() returns, so the method that called eof() can access the record + * from the buffer, for example to check its type or extract its payload. */ @Override public boolean eof() throws IOException { diff --git a/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java b/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java new file mode 100644 index 0000000000000000000000000000000000000000..5030f6931e3592618ed5abc7d73bae3550d45deb --- /dev/null +++ b/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java @@ -0,0 +1,251 @@ +package org.briarproject.bramble.keyagreement; + +import org.briarproject.bramble.api.keyagreement.KeyAgreementConnection; +import org.briarproject.bramble.api.plugin.TransportConnectionReader; +import org.briarproject.bramble.api.plugin.TransportConnectionWriter; +import org.briarproject.bramble.api.plugin.TransportId; +import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection; +import org.briarproject.bramble.test.BrambleMockTestCase; +import org.briarproject.bramble.test.TestUtils; +import org.briarproject.bramble.util.ByteUtils; +import org.jmock.Expectations; +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; + +import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION; +import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.RECORD_HEADER_LENGTH; +import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT; +import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM; +import static org.briarproject.bramble.api.keyagreement.RecordTypes.KEY; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +public class KeyAgreementTransportTest extends BrambleMockTestCase { + + private final DuplexTransportConnection duplexTransportConnection = + context.mock(DuplexTransportConnection.class); + private final TransportConnectionReader transportConnectionReader = + context.mock(TransportConnectionReader.class); + private final TransportConnectionWriter transportConnectionWriter = + context.mock(TransportConnectionWriter.class); + + private final TransportId transportId = new TransportId("test"); + private final KeyAgreementConnection keyAgreementConnection = + new KeyAgreementConnection(duplexTransportConnection, transportId); + + private ByteArrayInputStream inputStream; + private ByteArrayOutputStream outputStream; + private KeyAgreementTransport kat; + + @Test + public void testSendKey() throws Exception { + setup(new byte[0]); + byte[] key = TestUtils.getRandomBytes(123); + kat.sendKey(key); + assertRecordSent(KEY, key); + } + + @Test + public void testSendConfirm() throws Exception { + setup(new byte[0]); + byte[] confirm = TestUtils.getRandomBytes(123); + kat.sendConfirm(confirm); + assertRecordSent(CONFIRM, confirm); + } + + @Test + public void testSendAbortWithException() throws Exception { + setup(new byte[0]); + context.checking(new Expectations() {{ + oneOf(transportConnectionReader).dispose(true, true); + oneOf(transportConnectionWriter).dispose(true); + }}); + kat.sendAbort(true); + assertRecordSent(ABORT, new byte[0]); + } + + @Test + public void testSendAbortWithoutException() throws Exception { + setup(new byte[0]); + context.checking(new Expectations() {{ + oneOf(transportConnectionReader).dispose(false, true); + oneOf(transportConnectionWriter).dispose(false); + }}); + kat.sendAbort(false); + assertRecordSent(ABORT, new byte[0]); + } + + @Test(expected = AbortException.class) + public void testReceiveKeyThrowsExceptionIfAtEndOfStream() + throws Exception { + setup(new byte[0]); + kat.receiveKey(); + } + + @Test(expected = AbortException.class) + public void testReceiveKeyThrowsExceptionIfHeaderIsTooShort() + throws Exception { + byte[] input = new byte[RECORD_HEADER_LENGTH - 1]; + input[0] = PROTOCOL_VERSION; + input[1] = KEY; + setup(input); + kat.receiveKey(); + } + + @Test(expected = AbortException.class) + public void testReceiveKeyThrowsExceptionIfPayloadIsTooShort() + throws Exception { + int payloadLength = 123; + byte[] input = new byte[RECORD_HEADER_LENGTH + payloadLength - 1]; + input[0] = PROTOCOL_VERSION; + input[1] = KEY; + ByteUtils.writeUint16(payloadLength, input, 2); + setup(input); + kat.receiveKey(); + } + + @Test(expected = AbortException.class) + public void testReceiveKeyThrowsExceptionIfProtocolVersionIsUnrecognised() + throws Exception { + setup(createRecord((byte) (PROTOCOL_VERSION + 1), KEY, new byte[123])); + kat.receiveKey(); + } + + @Test(expected = AbortException.class) + public void testReceiveKeyThrowsExceptionIfAbortIsReceived() + throws Exception { + setup(createRecord(PROTOCOL_VERSION, ABORT, new byte[0])); + kat.receiveKey(); + } + + @Test(expected = AbortException.class) + public void testReceiveKeyThrowsExceptionIfConfirmIsReceived() + throws Exception { + setup(createRecord(PROTOCOL_VERSION, CONFIRM, new byte[123])); + kat.receiveKey(); + } + + @Test + public void testReceiveKeySkipsUnrecognisedRecordTypes() throws Exception { + byte[] skip1 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 1), + new byte[123]); + byte[] skip2 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 2), + new byte[0]); + byte[] payload = TestUtils.getRandomBytes(123); + byte[] key = createRecord(PROTOCOL_VERSION, KEY, payload); + ByteArrayOutputStream input = new ByteArrayOutputStream(); + input.write(skip1); + input.write(skip2); + input.write(key); + setup(input.toByteArray()); + assertArrayEquals(payload, kat.receiveKey()); + } + + @Test(expected = AbortException.class) + public void testReceiveConfirmThrowsExceptionIfAtEndOfStream() + throws Exception { + setup(new byte[0]); + kat.receiveConfirm(); + } + + @Test(expected = AbortException.class) + public void testReceiveConfirmThrowsExceptionIfHeaderIsTooShort() + throws Exception { + byte[] input = new byte[RECORD_HEADER_LENGTH - 1]; + input[0] = PROTOCOL_VERSION; + input[1] = CONFIRM; + setup(input); + kat.receiveConfirm(); + } + + @Test(expected = AbortException.class) + public void testReceiveConfirmThrowsExceptionIfPayloadIsTooShort() + throws Exception { + int payloadLength = 123; + byte[] input = new byte[RECORD_HEADER_LENGTH + payloadLength - 1]; + input[0] = PROTOCOL_VERSION; + input[1] = CONFIRM; + ByteUtils.writeUint16(payloadLength, input, 2); + setup(input); + kat.receiveConfirm(); + } + + @Test(expected = AbortException.class) + public void testReceiveConfirmThrowsExceptionIfProtocolVersionIsUnrecognised() + throws Exception { + setup(createRecord((byte) (PROTOCOL_VERSION + 1), CONFIRM, + new byte[123])); + kat.receiveConfirm(); + } + + @Test(expected = AbortException.class) + public void testReceiveConfirmThrowsExceptionIfAbortIsReceived() + throws Exception { + setup(createRecord(PROTOCOL_VERSION, ABORT, new byte[0])); + kat.receiveConfirm(); + } + + @Test(expected = AbortException.class) + public void testReceiveKeyThrowsExceptionIfKeyIsReceived() + throws Exception { + setup(createRecord(PROTOCOL_VERSION, KEY, new byte[123])); + kat.receiveConfirm(); + } + + @Test + public void testReceiveConfirmSkipsUnrecognisedRecordTypes() + throws Exception { + byte[] skip1 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 1), + new byte[123]); + byte[] skip2 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 2), + new byte[0]); + byte[] payload = TestUtils.getRandomBytes(123); + byte[] confirm = createRecord(PROTOCOL_VERSION, CONFIRM, payload); + ByteArrayOutputStream input = new ByteArrayOutputStream(); + input.write(skip1); + input.write(skip2); + input.write(confirm); + setup(input.toByteArray()); + assertArrayEquals(payload, kat.receiveConfirm()); + } + + private void setup(byte[] input) throws Exception { + inputStream = new ByteArrayInputStream(input); + outputStream = new ByteArrayOutputStream(); + context.checking(new Expectations() {{ + allowing(duplexTransportConnection).getReader(); + will(returnValue(transportConnectionReader)); + allowing(transportConnectionReader).getInputStream(); + will(returnValue(inputStream)); + allowing(duplexTransportConnection).getWriter(); + will(returnValue(transportConnectionWriter)); + allowing(transportConnectionWriter).getOutputStream(); + will(returnValue(outputStream)); + }}); + kat = new KeyAgreementTransport(keyAgreementConnection); + } + + private void assertRecordSent(byte expectedType, byte[] expectedPayload) { + byte[] output = outputStream.toByteArray(); + assertEquals(RECORD_HEADER_LENGTH + expectedPayload.length, + output.length); + assertEquals(PROTOCOL_VERSION, output[0]); + assertEquals(expectedType, output[1]); + assertEquals(expectedPayload.length, ByteUtils.readUint16(output, 2)); + byte[] payload = new byte[output.length - RECORD_HEADER_LENGTH]; + System.arraycopy(output, RECORD_HEADER_LENGTH, payload, 0, + payload.length); + assertArrayEquals(expectedPayload, payload); + } + + private byte[] createRecord(byte version, byte type, byte[] payload) { + byte[] b = new byte[RECORD_HEADER_LENGTH + payload.length]; + b[0] = version; + b[1] = type; + ByteUtils.writeUint16(payload.length, b, 2); + System.arraycopy(payload, 0, b, RECORD_HEADER_LENGTH, payload.length); + return b; + } +} diff --git a/bramble-core/src/test/java/org/briarproject/bramble/sync/RecordReaderImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/sync/RecordReaderImplTest.java index 8d1fca5154b4883f651c135be5a6eb19900afc21..7b927cf59b3aad4211f1c8e00cdcfd746f253ba6 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/sync/RecordReaderImplTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/sync/RecordReaderImplTest.java @@ -2,7 +2,9 @@ package org.briarproject.bramble.sync; import org.briarproject.bramble.api.FormatException; import org.briarproject.bramble.api.UniqueId; -import org.briarproject.bramble.test.BrambleTestCase; +import org.briarproject.bramble.api.sync.Ack; +import org.briarproject.bramble.api.sync.MessageFactory; +import org.briarproject.bramble.test.BrambleMockTestCase; import org.briarproject.bramble.test.TestUtils; import org.briarproject.bramble.util.ByteUtils; import org.junit.Test; @@ -13,17 +15,24 @@ import java.io.ByteArrayOutputStream; import static org.briarproject.bramble.api.sync.RecordTypes.ACK; import static org.briarproject.bramble.api.sync.RecordTypes.OFFER; import static org.briarproject.bramble.api.sync.RecordTypes.REQUEST; +import static org.briarproject.bramble.api.sync.SyncConstants.MAX_MESSAGE_IDS; import static org.briarproject.bramble.api.sync.SyncConstants.MAX_RECORD_PAYLOAD_LENGTH; +import static org.briarproject.bramble.api.sync.SyncConstants.PROTOCOL_VERSION; import static org.briarproject.bramble.api.sync.SyncConstants.RECORD_HEADER_LENGTH; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; -public class RecordReaderImplTest extends BrambleTestCase { +public class RecordReaderImplTest extends BrambleMockTestCase { + + private final MessageFactory messageFactory = + context.mock(MessageFactory.class); @Test(expected = FormatException.class) public void testFormatExceptionIfAckIsTooLarge() throws Exception { byte[] b = createAck(true); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readAck(); } @@ -31,15 +40,15 @@ public class RecordReaderImplTest extends BrambleTestCase { public void testNoFormatExceptionIfAckIsMaximumSize() throws Exception { byte[] b = createAck(false); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readAck(); } @Test(expected = FormatException.class) - public void testEmptyAck() throws Exception { + public void testFormatExceptionIfAckIsEmpty() throws Exception { byte[] b = createEmptyAck(); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readAck(); } @@ -47,7 +56,7 @@ public class RecordReaderImplTest extends BrambleTestCase { public void testFormatExceptionIfOfferIsTooLarge() throws Exception { byte[] b = createOffer(true); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readOffer(); } @@ -55,15 +64,15 @@ public class RecordReaderImplTest extends BrambleTestCase { public void testNoFormatExceptionIfOfferIsMaximumSize() throws Exception { byte[] b = createOffer(false); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readOffer(); } @Test(expected = FormatException.class) - public void testEmptyOffer() throws Exception { + public void testFormatExceptionIfOfferIsEmpty() throws Exception { byte[] b = createEmptyOffer(); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readOffer(); } @@ -71,7 +80,7 @@ public class RecordReaderImplTest extends BrambleTestCase { public void testFormatExceptionIfRequestIsTooLarge() throws Exception { byte[] b = createRequest(true); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readRequest(); } @@ -79,84 +88,132 @@ public class RecordReaderImplTest extends BrambleTestCase { public void testNoFormatExceptionIfRequestIsMaximumSize() throws Exception { byte[] b = createRequest(false); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readRequest(); } @Test(expected = FormatException.class) - public void testEmptyRequest() throws Exception { + public void testFormatExceptionIfRequestIsEmpty() throws Exception { byte[] b = createEmptyRequest(); ByteArrayInputStream in = new ByteArrayInputStream(b); - RecordReaderImpl reader = new RecordReaderImpl(null, in); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); reader.readRequest(); } + @Test + public void testEofReturnsTrueWhenAtEndOfStream() throws Exception { + ByteArrayInputStream in = new ByteArrayInputStream(new byte[0]); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); + assertTrue(reader.eof()); + } + + @Test + public void testEofReturnsFalseWhenNotAtEndOfStream() throws Exception { + byte[] b = createAck(false); + ByteArrayInputStream in = new ByteArrayInputStream(b); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); + assertFalse(reader.eof()); + } + + @Test(expected = FormatException.class) + public void testThrowsExceptionIfHeaderIsTooShort() throws Exception { + byte[] b = new byte[RECORD_HEADER_LENGTH - 1]; + b[0] = PROTOCOL_VERSION; + b[1] = ACK; + ByteArrayInputStream in = new ByteArrayInputStream(b); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); + reader.eof(); + } + + @Test(expected = FormatException.class) + public void testThrowsExceptionIfPayloadIsTooShort() throws Exception { + int payloadLength = 123; + byte[] b = new byte[RECORD_HEADER_LENGTH + payloadLength - 1]; + b[0] = PROTOCOL_VERSION; + b[1] = ACK; + ByteUtils.writeUint16(payloadLength, b, 2); + ByteArrayInputStream in = new ByteArrayInputStream(b); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); + reader.eof(); + } + + @Test(expected = FormatException.class) + public void testThrowsExceptionIfProtocolVersionIsUnrecognised() + throws Exception { + byte version = (byte) (PROTOCOL_VERSION + 1); + byte[] b = createRecord(version, ACK, new byte[0]); + ByteArrayInputStream in = new ByteArrayInputStream(b); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); + reader.eof(); + } + + @Test(expected = FormatException.class) + public void testThrowsExceptionIfPayloadIsTooLong() throws Exception { + byte[] payload = new byte[MAX_RECORD_PAYLOAD_LENGTH + 1]; + byte[] b = createRecord(PROTOCOL_VERSION, ACK, payload); + ByteArrayInputStream in = new ByteArrayInputStream(b); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); + reader.eof(); + } + + @Test + public void testSkipsUnrecognisedRecordTypes() throws Exception { + byte[] skip1 = createRecord(PROTOCOL_VERSION, (byte) (REQUEST + 1), + new byte[123]); + byte[] skip2 = createRecord(PROTOCOL_VERSION, (byte) (REQUEST + 2), + new byte[0]); + byte[] ack = createAck(false); + ByteArrayOutputStream input = new ByteArrayOutputStream(); + input.write(skip1); + input.write(skip2); + input.write(ack); + ByteArrayInputStream in = new ByteArrayInputStream(input.toByteArray()); + RecordReaderImpl reader = new RecordReaderImpl(messageFactory, in); + assertTrue(reader.hasAck()); + Ack a = reader.readAck(); + assertEquals(MAX_MESSAGE_IDS, a.getMessageIds().size()); + } + private byte[] createAck(boolean tooBig) throws Exception { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - out.write(new byte[RECORD_HEADER_LENGTH]); - while (out.size() + UniqueId.LENGTH <= RECORD_HEADER_LENGTH - + MAX_RECORD_PAYLOAD_LENGTH) { - out.write(TestUtils.getRandomId()); - } - if (tooBig) out.write(TestUtils.getRandomId()); - assertEquals(tooBig, out.size() > RECORD_HEADER_LENGTH + - MAX_RECORD_PAYLOAD_LENGTH); - byte[] record = out.toByteArray(); - record[1] = ACK; - ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2); - return record; + return createRecord(PROTOCOL_VERSION, ACK, createPayload(tooBig)); } private byte[] createEmptyAck() throws Exception { - byte[] record = new byte[RECORD_HEADER_LENGTH]; - record[1] = ACK; - ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2); - return record; + return createRecord(PROTOCOL_VERSION, ACK, new byte[0]); } private byte[] createOffer(boolean tooBig) throws Exception { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - out.write(new byte[RECORD_HEADER_LENGTH]); - while (out.size() + UniqueId.LENGTH <= RECORD_HEADER_LENGTH - + MAX_RECORD_PAYLOAD_LENGTH) { - out.write(TestUtils.getRandomId()); - } - if (tooBig) out.write(TestUtils.getRandomId()); - assertEquals(tooBig, out.size() > RECORD_HEADER_LENGTH + - MAX_RECORD_PAYLOAD_LENGTH); - byte[] record = out.toByteArray(); - record[1] = OFFER; - ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2); - return record; + return createRecord(PROTOCOL_VERSION, OFFER, createPayload(tooBig)); } private byte[] createEmptyOffer() throws Exception { - byte[] record = new byte[RECORD_HEADER_LENGTH]; - record[1] = OFFER; - ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2); - return record; + return createRecord(PROTOCOL_VERSION, OFFER, new byte[0]); } private byte[] createRequest(boolean tooBig) throws Exception { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - out.write(new byte[RECORD_HEADER_LENGTH]); - while (out.size() + UniqueId.LENGTH <= RECORD_HEADER_LENGTH - + MAX_RECORD_PAYLOAD_LENGTH) { - out.write(TestUtils.getRandomId()); - } - if (tooBig) out.write(TestUtils.getRandomId()); - assertEquals(tooBig, out.size() > RECORD_HEADER_LENGTH + - MAX_RECORD_PAYLOAD_LENGTH); - byte[] record = out.toByteArray(); - record[1] = REQUEST; - ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2); - return record; + return createRecord(PROTOCOL_VERSION, REQUEST, createPayload(tooBig)); } private byte[] createEmptyRequest() throws Exception { - byte[] record = new byte[RECORD_HEADER_LENGTH]; - record[1] = REQUEST; - ByteUtils.writeUint16(record.length - RECORD_HEADER_LENGTH, record, 2); - return record; + return createRecord(PROTOCOL_VERSION, REQUEST, new byte[0]); + } + + private byte[] createRecord(byte version, byte type, byte[] payload) { + byte[] b = new byte[RECORD_HEADER_LENGTH + payload.length]; + b[0] = version; + b[1] = type; + ByteUtils.writeUint16(payload.length, b, 2); + System.arraycopy(payload, 0, b, RECORD_HEADER_LENGTH, payload.length); + return b; + } + + private byte[] createPayload(boolean tooBig) throws Exception { + ByteArrayOutputStream payload = new ByteArrayOutputStream(); + while (payload.size() + UniqueId.LENGTH <= MAX_RECORD_PAYLOAD_LENGTH) { + payload.write(TestUtils.getRandomId()); + } + if (tooBig) payload.write(TestUtils.getRandomId()); + assertEquals(tooBig, payload.size() > MAX_RECORD_PAYLOAD_LENGTH); + return payload.toByteArray(); } }