diff --git a/bramble-core/src/main/java/org/briarproject/bramble/BrambleCoreModule.java b/bramble-core/src/main/java/org/briarproject/bramble/BrambleCoreModule.java index b92329b187ea87e63ae703849c4de08f8d255a35..f4ece8bd4903288608ac5b8cf3594de90bfa5340 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/BrambleCoreModule.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/BrambleCoreModule.java @@ -13,6 +13,7 @@ import org.briarproject.bramble.keyagreement.KeyAgreementModule; import org.briarproject.bramble.lifecycle.LifecycleModule; import org.briarproject.bramble.plugin.PluginModule; import org.briarproject.bramble.properties.PropertiesModule; +import org.briarproject.bramble.record.RecordModule; import org.briarproject.bramble.reliability.ReliabilityModule; import org.briarproject.bramble.reporting.ReportingModule; import org.briarproject.bramble.settings.SettingsModule; @@ -38,6 +39,7 @@ import dagger.Module; LifecycleModule.class, PluginModule.class, PropertiesModule.class, + RecordModule.class, ReliabilityModule.class, ReportingModule.class, SettingsModule.class, diff --git a/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementConnector.java b/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementConnector.java index 079b9225317c4a8107d735685f0b02cbabdfe891..0c24f1c9f69d99105ba0255feefb7fc0a1ddadfe 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementConnector.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementConnector.java @@ -13,6 +13,8 @@ import org.briarproject.bramble.api.plugin.PluginManager; import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.duplex.DuplexPlugin; import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection; +import org.briarproject.bramble.api.record.RecordReaderFactory; +import org.briarproject.bramble.api.record.RecordWriterFactory; import java.io.IOException; import java.io.InputStream; @@ -44,6 +46,8 @@ class KeyAgreementConnector { private final KeyAgreementCrypto keyAgreementCrypto; private final PluginManager pluginManager; private final ConnectionChooser connectionChooser; + private final RecordReaderFactory recordReaderFactory; + private final RecordWriterFactory recordWriterFactory; private final List<KeyAgreementListener> listeners = new CopyOnWriteArrayList<>(); @@ -54,11 +58,15 @@ class KeyAgreementConnector { KeyAgreementConnector(Callbacks callbacks, KeyAgreementCrypto keyAgreementCrypto, PluginManager pluginManager, - ConnectionChooser connectionChooser) { + ConnectionChooser connectionChooser, + RecordReaderFactory recordReaderFactory, + RecordWriterFactory recordWriterFactory) { this.callbacks = callbacks; this.keyAgreementCrypto = keyAgreementCrypto; this.pluginManager = pluginManager; this.connectionChooser = connectionChooser; + this.recordReaderFactory = recordReaderFactory; + this.recordWriterFactory = recordWriterFactory; } Payload listen(KeyPair localKeyPair) { @@ -119,7 +127,8 @@ class KeyAgreementConnector { KeyAgreementConnection chosen = connectionChooser.poll(CONNECTION_TIMEOUT); if (chosen == null) return null; - return new KeyAgreementTransport(chosen); + return new KeyAgreementTransport(recordReaderFactory, + recordWriterFactory, chosen); } catch (InterruptedException e) { LOG.info("Interrupted while waiting for connection"); Thread.currentThread().interrupt(); diff --git a/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTaskImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTaskImpl.java index 1b3d9abea84d7237e58bb09a51ccf04e07182537..b5fc7ef7c0fcd091b0e44f3b818481c93c2d4400 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTaskImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTaskImpl.java @@ -19,6 +19,8 @@ import org.briarproject.bramble.api.keyagreement.event.KeyAgreementWaitingEvent; import org.briarproject.bramble.api.nullsafety.MethodsNotNullByDefault; import org.briarproject.bramble.api.nullsafety.ParametersNotNullByDefault; import org.briarproject.bramble.api.plugin.PluginManager; +import org.briarproject.bramble.api.record.RecordReaderFactory; +import org.briarproject.bramble.api.record.RecordWriterFactory; import java.io.IOException; import java.util.logging.Logger; @@ -49,14 +51,17 @@ class KeyAgreementTaskImpl extends Thread implements KeyAgreementTask, KeyAgreementTaskImpl(CryptoComponent crypto, KeyAgreementCrypto keyAgreementCrypto, EventBus eventBus, PayloadEncoder payloadEncoder, PluginManager pluginManager, - ConnectionChooser connectionChooser) { + ConnectionChooser connectionChooser, + RecordReaderFactory recordReaderFactory, + RecordWriterFactory recordWriterFactory) { this.crypto = crypto; this.keyAgreementCrypto = keyAgreementCrypto; this.eventBus = eventBus; this.payloadEncoder = payloadEncoder; localKeyPair = crypto.generateAgreementKeyPair(); connector = new KeyAgreementConnector(this, keyAgreementCrypto, - pluginManager, connectionChooser); + pluginManager, connectionChooser, recordReaderFactory, + recordWriterFactory); } @Override diff --git a/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java b/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java index 5a58743ff57a029c8d440e0b951323da01ba1796..c545fbd69e8c6662e2edc16c73edefa5e12fc16f 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/keyagreement/KeyAgreementTransport.java @@ -4,9 +4,12 @@ import org.briarproject.bramble.api.keyagreement.KeyAgreementConnection; import org.briarproject.bramble.api.nullsafety.NotNullByDefault; import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection; -import org.briarproject.bramble.util.ByteUtils; +import org.briarproject.bramble.api.record.Record; +import org.briarproject.bramble.api.record.RecordReader; +import org.briarproject.bramble.api.record.RecordReaderFactory; +import org.briarproject.bramble.api.record.RecordWriter; +import org.briarproject.bramble.api.record.RecordWriterFactory; -import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -14,8 +17,6 @@ import java.util.logging.Logger; import static java.util.logging.Level.WARNING; import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION; -import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.RECORD_HEADER_LENGTH; -import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.RECORD_HEADER_PAYLOAD_LENGTH_OFFSET; import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT; import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM; import static org.briarproject.bramble.api.keyagreement.RecordTypes.KEY; @@ -30,14 +31,17 @@ class KeyAgreementTransport { Logger.getLogger(KeyAgreementTransport.class.getName()); private final KeyAgreementConnection kac; - private final InputStream in; - private final OutputStream out; + private final RecordReader reader; + private final RecordWriter writer; - KeyAgreementTransport(KeyAgreementConnection kac) + KeyAgreementTransport(RecordReaderFactory recordReaderFactory, + RecordWriterFactory recordWriterFactory, KeyAgreementConnection kac) throws IOException { this.kac = kac; - in = kac.getConnection().getReader().getInputStream(); - out = kac.getConnection().getWriter().getOutputStream(); + InputStream in = kac.getConnection().getReader().getInputStream(); + reader = recordReaderFactory.createRecordReader(in); + OutputStream out = kac.getConnection().getWriter().getOutputStream(); + writer = recordWriterFactory.createRecordWriter(out); } public DuplexTransportConnection getConnection() { @@ -74,9 +78,8 @@ class KeyAgreementTransport { tryToClose(exception); } - public void tryToClose(boolean exception) { + private void tryToClose(boolean exception) { try { - LOG.info("Closing connection"); kac.getConnection().getReader().dispose(exception, true); kac.getConnection().getWriter().dispose(exception); } catch (IOException e) { @@ -85,59 +88,27 @@ class KeyAgreementTransport { } private void writeRecord(byte type, byte[] payload) throws IOException { - byte[] recordHeader = new byte[RECORD_HEADER_LENGTH]; - recordHeader[0] = PROTOCOL_VERSION; - recordHeader[1] = type; - ByteUtils.writeUint16(payload.length, recordHeader, - RECORD_HEADER_PAYLOAD_LENGTH_OFFSET); - out.write(recordHeader); - out.write(payload); - out.flush(); + writer.writeRecord(new Record(PROTOCOL_VERSION, type, payload)); + writer.flush(); } private byte[] readRecord(byte expectedType) throws AbortException { while (true) { - byte[] header = readHeader(); - byte version = header[0], type = header[1]; - int len = ByteUtils.readUint16(header, - RECORD_HEADER_PAYLOAD_LENGTH_OFFSET); - // Reject unrecognised protocol version - if (version != PROTOCOL_VERSION) throw new AbortException(false); - if (type == ABORT) throw new AbortException(true); - if (type == expectedType) { - try { - return readData(len); - } catch (IOException e) { - throw new AbortException(e); - } - } - // Reject recognised but unexpected record type - if (type == KEY || type == CONFIRM) throw new AbortException(false); - // Skip unrecognised record type try { - readData(len); + Record record = reader.readRecord(); + // Reject unrecognised protocol version + if (record.getProtocolVersion() != PROTOCOL_VERSION) + throw new AbortException(false); + byte type = record.getRecordType(); + if (type == ABORT) throw new AbortException(true); + if (type == expectedType) return record.getPayload(); + // Reject recognised but unexpected record type + if (type == KEY || type == CONFIRM) + throw new AbortException(false); + // Skip unrecognised record type } catch (IOException e) { throw new AbortException(e); } } } - - private byte[] readHeader() throws AbortException { - try { - return readData(RECORD_HEADER_LENGTH); - } catch (IOException e) { - throw new AbortException(e); - } - } - - private byte[] readData(int len) throws IOException { - byte[] data = new byte[len]; - int offset = 0; - while (offset < data.length) { - int read = in.read(data, offset, data.length - offset); - if (read == -1) throw new EOFException(); - offset += read; - } - return data; - } } diff --git a/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java b/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java index b42fbbd25442572cecb1bbcf3cacff9ba16dd64f..cddb5b40ec99f57b350c4ec89ef888554ac748a3 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/keyagreement/KeyAgreementTransportTest.java @@ -5,23 +5,31 @@ import org.briarproject.bramble.api.plugin.TransportConnectionReader; import org.briarproject.bramble.api.plugin.TransportConnectionWriter; import org.briarproject.bramble.api.plugin.TransportId; import org.briarproject.bramble.api.plugin.duplex.DuplexTransportConnection; +import org.briarproject.bramble.api.record.Record; +import org.briarproject.bramble.api.record.RecordReader; +import org.briarproject.bramble.api.record.RecordReaderFactory; +import org.briarproject.bramble.api.record.RecordWriter; +import org.briarproject.bramble.api.record.RecordWriterFactory; import org.briarproject.bramble.test.BrambleMockTestCase; -import org.briarproject.bramble.test.TestUtils; -import org.briarproject.bramble.util.ByteUtils; +import org.briarproject.bramble.test.CaptureArgumentAction; import org.jmock.Expectations; +import org.jmock.lib.legacy.ClassImposteriser; import org.junit.Test; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; +import java.io.EOFException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.concurrent.atomic.AtomicReference; import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.PROTOCOL_VERSION; -import static org.briarproject.bramble.api.keyagreement.KeyAgreementConstants.RECORD_HEADER_LENGTH; import static org.briarproject.bramble.api.keyagreement.RecordTypes.ABORT; import static org.briarproject.bramble.api.keyagreement.RecordTypes.CONFIRM; import static org.briarproject.bramble.api.keyagreement.RecordTypes.KEY; +import static org.briarproject.bramble.test.TestUtils.getRandomBytes; import static org.briarproject.bramble.test.TestUtils.getTransportId; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; public class KeyAgreementTransportTest extends BrambleMockTestCase { @@ -31,222 +39,268 @@ public class KeyAgreementTransportTest extends BrambleMockTestCase { context.mock(TransportConnectionReader.class); private final TransportConnectionWriter transportConnectionWriter = context.mock(TransportConnectionWriter.class); + private final RecordReaderFactory recordReaderFactory = + context.mock(RecordReaderFactory.class); + private final RecordWriterFactory recordWriterFactory = + context.mock(RecordWriterFactory.class); + private final RecordReader recordReader = context.mock(RecordReader.class); + private final RecordWriter recordWriter = context.mock(RecordWriter.class); private final TransportId transportId = getTransportId(); private final KeyAgreementConnection keyAgreementConnection = new KeyAgreementConnection(duplexTransportConnection, transportId); - private ByteArrayInputStream inputStream; - private ByteArrayOutputStream outputStream; + private final InputStream inputStream; + private final OutputStream outputStream; + private KeyAgreementTransport kat; + public KeyAgreementTransportTest() { + context.setImposteriser(ClassImposteriser.INSTANCE); + inputStream = context.mock(InputStream.class); + outputStream = context.mock(OutputStream.class); + } + @Test public void testSendKey() throws Exception { - setup(new byte[0]); - byte[] key = TestUtils.getRandomBytes(123); + byte[] key = getRandomBytes(123); + + setup(); + AtomicReference<Record> written = expectWriteRecord(); + kat.sendKey(key); - assertRecordSent(KEY, key); + assertNotNull(written.get()); + assertRecordEquals(PROTOCOL_VERSION, KEY, key, written.get()); } @Test public void testSendConfirm() throws Exception { - setup(new byte[0]); - byte[] confirm = TestUtils.getRandomBytes(123); + byte[] confirm = getRandomBytes(123); + + setup(); + AtomicReference<Record> written = expectWriteRecord(); + kat.sendConfirm(confirm); - assertRecordSent(CONFIRM, confirm); + assertNotNull(written.get()); + assertRecordEquals(PROTOCOL_VERSION, CONFIRM, confirm, written.get()); } @Test public void testSendAbortWithException() throws Exception { - setup(new byte[0]); + setup(); + AtomicReference<Record> written = expectWriteRecord(); context.checking(new Expectations() {{ oneOf(transportConnectionReader).dispose(true, true); oneOf(transportConnectionWriter).dispose(true); }}); + kat.sendAbort(true); - assertRecordSent(ABORT, new byte[0]); + assertNotNull(written.get()); + assertRecordEquals(PROTOCOL_VERSION, ABORT, new byte[0], written.get()); } @Test public void testSendAbortWithoutException() throws Exception { - setup(new byte[0]); + setup(); + AtomicReference<Record> written = expectWriteRecord(); context.checking(new Expectations() {{ oneOf(transportConnectionReader).dispose(false, true); oneOf(transportConnectionWriter).dispose(false); }}); + kat.sendAbort(false); - assertRecordSent(ABORT, new byte[0]); + assertNotNull(written.get()); + assertRecordEquals(PROTOCOL_VERSION, ABORT, new byte[0], written.get()); } @Test(expected = AbortException.class) public void testReceiveKeyThrowsExceptionIfAtEndOfStream() throws Exception { - setup(new byte[0]); - kat.receiveKey(); - } + setup(); + context.checking(new Expectations() {{ + oneOf(recordReader).readRecord(); + will(throwException(new EOFException())); + }}); - @Test(expected = AbortException.class) - public void testReceiveKeyThrowsExceptionIfHeaderIsTooShort() - throws Exception { - byte[] input = new byte[RECORD_HEADER_LENGTH - 1]; - input[0] = PROTOCOL_VERSION; - input[1] = KEY; - setup(input); - kat.receiveKey(); - } - - @Test(expected = AbortException.class) - public void testReceiveKeyThrowsExceptionIfPayloadIsTooShort() - throws Exception { - int payloadLength = 123; - byte[] input = new byte[RECORD_HEADER_LENGTH + payloadLength - 1]; - input[0] = PROTOCOL_VERSION; - input[1] = KEY; - ByteUtils.writeUint16(payloadLength, input, 2); - setup(input); kat.receiveKey(); } @Test(expected = AbortException.class) public void testReceiveKeyThrowsExceptionIfProtocolVersionIsUnrecognised() throws Exception { - setup(createRecord((byte) (PROTOCOL_VERSION + 1), KEY, new byte[123])); + byte unknownVersion = (byte) (PROTOCOL_VERSION + 1); + byte[] key = getRandomBytes(123); + + setup(); + context.checking(new Expectations() {{ + oneOf(recordReader).readRecord(); + will(returnValue(new Record(unknownVersion, KEY, key))); + }}); + kat.receiveKey(); } @Test(expected = AbortException.class) public void testReceiveKeyThrowsExceptionIfAbortIsReceived() throws Exception { - setup(createRecord(PROTOCOL_VERSION, ABORT, new byte[0])); + setup(); + context.checking(new Expectations() {{ + oneOf(recordReader).readRecord(); + will(returnValue(new Record(PROTOCOL_VERSION, ABORT, new byte[0]))); + }}); + kat.receiveKey(); } @Test(expected = AbortException.class) public void testReceiveKeyThrowsExceptionIfConfirmIsReceived() throws Exception { - setup(createRecord(PROTOCOL_VERSION, CONFIRM, new byte[123])); + byte[] confirm = getRandomBytes(123); + + setup(); + context.checking(new Expectations() {{ + oneOf(recordReader).readRecord(); + will(returnValue(new Record(PROTOCOL_VERSION, CONFIRM, confirm))); + }}); + kat.receiveKey(); } @Test public void testReceiveKeySkipsUnrecognisedRecordTypes() throws Exception { - byte[] skip1 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 1), - new byte[123]); - byte[] skip2 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 2), - new byte[0]); - byte[] payload = TestUtils.getRandomBytes(123); - byte[] key = createRecord(PROTOCOL_VERSION, KEY, payload); - ByteArrayOutputStream input = new ByteArrayOutputStream(); - input.write(skip1); - input.write(skip2); - input.write(key); - setup(input.toByteArray()); - assertArrayEquals(payload, kat.receiveKey()); - } + byte type1 = (byte) (ABORT + 1); + byte[] payload1 = getRandomBytes(123); + Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1); + byte type2 = (byte) (ABORT + 2); + byte[] payload2 = new byte[0]; + Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2); + byte[] key = getRandomBytes(123); + Record keyRecord = new Record(PROTOCOL_VERSION, KEY, key); - @Test(expected = AbortException.class) - public void testReceiveConfirmThrowsExceptionIfAtEndOfStream() - throws Exception { - setup(new byte[0]); - kat.receiveConfirm(); - } + setup(); + context.checking(new Expectations() {{ + oneOf(recordReader).readRecord(); + will(returnValue(unknownRecord1)); + oneOf(recordReader).readRecord(); + will(returnValue(unknownRecord2)); + oneOf(recordReader).readRecord(); + will(returnValue(keyRecord)); + }}); - @Test(expected = AbortException.class) - public void testReceiveConfirmThrowsExceptionIfHeaderIsTooShort() - throws Exception { - byte[] input = new byte[RECORD_HEADER_LENGTH - 1]; - input[0] = PROTOCOL_VERSION; - input[1] = CONFIRM; - setup(input); - kat.receiveConfirm(); + assertArrayEquals(key, kat.receiveKey()); } @Test(expected = AbortException.class) - public void testReceiveConfirmThrowsExceptionIfPayloadIsTooShort() + public void testReceiveConfirmThrowsExceptionIfAtEndOfStream() throws Exception { - int payloadLength = 123; - byte[] input = new byte[RECORD_HEADER_LENGTH + payloadLength - 1]; - input[0] = PROTOCOL_VERSION; - input[1] = CONFIRM; - ByteUtils.writeUint16(payloadLength, input, 2); - setup(input); + setup(); + context.checking(new Expectations() {{ + oneOf(recordReader).readRecord(); + will(throwException(new EOFException())); + }}); + kat.receiveConfirm(); } @Test(expected = AbortException.class) public void testReceiveConfirmThrowsExceptionIfProtocolVersionIsUnrecognised() throws Exception { - setup(createRecord((byte) (PROTOCOL_VERSION + 1), CONFIRM, - new byte[123])); + byte unknownVersion = (byte) (PROTOCOL_VERSION + 1); + byte[] confirm = getRandomBytes(123); + + setup(); + context.checking(new Expectations() {{ + oneOf(recordReader).readRecord(); + will(returnValue(new Record(unknownVersion, CONFIRM, confirm))); + }}); + kat.receiveConfirm(); } @Test(expected = AbortException.class) public void testReceiveConfirmThrowsExceptionIfAbortIsReceived() throws Exception { - setup(createRecord(PROTOCOL_VERSION, ABORT, new byte[0])); + setup(); + context.checking(new Expectations() {{ + oneOf(recordReader).readRecord(); + will(returnValue(new Record(PROTOCOL_VERSION, ABORT, new byte[0]))); + }}); + kat.receiveConfirm(); } @Test(expected = AbortException.class) - public void testReceiveKeyThrowsExceptionIfKeyIsReceived() + public void testReceiveConfirmThrowsExceptionIfKeyIsReceived() throws Exception { - setup(createRecord(PROTOCOL_VERSION, KEY, new byte[123])); + byte[] key = getRandomBytes(123); + + setup(); + context.checking(new Expectations() {{ + oneOf(recordReader).readRecord(); + will(returnValue(new Record(PROTOCOL_VERSION, KEY, key))); + }}); + kat.receiveConfirm(); } @Test public void testReceiveConfirmSkipsUnrecognisedRecordTypes() throws Exception { - byte[] skip1 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 1), - new byte[123]); - byte[] skip2 = createRecord(PROTOCOL_VERSION, (byte) (ABORT + 2), - new byte[0]); - byte[] payload = TestUtils.getRandomBytes(123); - byte[] confirm = createRecord(PROTOCOL_VERSION, CONFIRM, payload); - ByteArrayOutputStream input = new ByteArrayOutputStream(); - input.write(skip1); - input.write(skip2); - input.write(confirm); - setup(input.toByteArray()); - assertArrayEquals(payload, kat.receiveConfirm()); - } - - private void setup(byte[] input) throws Exception { - inputStream = new ByteArrayInputStream(input); - outputStream = new ByteArrayOutputStream(); + byte type1 = (byte) (ABORT + 1); + byte[] payload1 = getRandomBytes(123); + Record unknownRecord1 = new Record(PROTOCOL_VERSION, type1, payload1); + byte type2 = (byte) (ABORT + 2); + byte[] payload2 = new byte[0]; + Record unknownRecord2 = new Record(PROTOCOL_VERSION, type2, payload2); + byte[] confirm = getRandomBytes(123); + Record confirmRecord = new Record(PROTOCOL_VERSION, CONFIRM, confirm); + + setup(); + context.checking(new Expectations() {{ + oneOf(recordReader).readRecord(); + will(returnValue(unknownRecord1)); + oneOf(recordReader).readRecord(); + will(returnValue(unknownRecord2)); + oneOf(recordReader).readRecord(); + will(returnValue(confirmRecord)); + }}); + + assertArrayEquals(confirm, kat.receiveConfirm()); + } + + private void setup() throws Exception { context.checking(new Expectations() {{ allowing(duplexTransportConnection).getReader(); will(returnValue(transportConnectionReader)); allowing(transportConnectionReader).getInputStream(); will(returnValue(inputStream)); + oneOf(recordReaderFactory).createRecordReader(inputStream); + will(returnValue(recordReader)); allowing(duplexTransportConnection).getWriter(); will(returnValue(transportConnectionWriter)); allowing(transportConnectionWriter).getOutputStream(); will(returnValue(outputStream)); + oneOf(recordWriterFactory).createRecordWriter(outputStream); + will(returnValue(recordWriter)); + }}); + kat = new KeyAgreementTransport(recordReaderFactory, + recordWriterFactory, keyAgreementConnection); + } + + private AtomicReference<Record> expectWriteRecord() throws Exception { + AtomicReference<Record> captured = new AtomicReference<>(); + context.checking(new Expectations() {{ + oneOf(recordWriter).writeRecord(with(any(Record.class))); + will(new CaptureArgumentAction<>(captured, Record.class, 0)); + oneOf(recordWriter).flush(); }}); - kat = new KeyAgreementTransport(keyAgreementConnection); - } - - private void assertRecordSent(byte expectedType, byte[] expectedPayload) { - byte[] output = outputStream.toByteArray(); - assertEquals(RECORD_HEADER_LENGTH + expectedPayload.length, - output.length); - assertEquals(PROTOCOL_VERSION, output[0]); - assertEquals(expectedType, output[1]); - assertEquals(expectedPayload.length, ByteUtils.readUint16(output, 2)); - byte[] payload = new byte[output.length - RECORD_HEADER_LENGTH]; - System.arraycopy(output, RECORD_HEADER_LENGTH, payload, 0, - payload.length); - assertArrayEquals(expectedPayload, payload); - } - - private byte[] createRecord(byte version, byte type, byte[] payload) { - byte[] b = new byte[RECORD_HEADER_LENGTH + payload.length]; - b[0] = version; - b[1] = type; - ByteUtils.writeUint16(payload.length, b, 2); - System.arraycopy(payload, 0, b, RECORD_HEADER_LENGTH, payload.length); - return b; + return captured; + } + + private void assertRecordEquals(byte expectedVersion, byte expectedType, + byte[] expectedPayload, Record actual) { + assertEquals(expectedVersion, actual.getProtocolVersion()); + assertEquals(expectedType, actual.getRecordType()); + assertArrayEquals(expectedPayload, actual.getPayload()); } }