diff --git a/components/net/sf/briar/transport/ConnectionReaderImpl.java b/components/net/sf/briar/transport/ConnectionReaderImpl.java index efebd6b301e1a1b1db18576e093e1e6a02d5279c..9fef13c9f8a292b2bb8fdb098abbd635e242bd0c 100644 --- a/components/net/sf/briar/transport/ConnectionReaderImpl.java +++ b/components/net/sf/briar/transport/ConnectionReaderImpl.java @@ -19,6 +19,7 @@ class ConnectionReaderImpl extends InputStream implements ConnectionReader { ConnectionReaderImpl(IncomingReliabilityLayer in, boolean tolerateErrors) { this.in = in; this.tolerateErrors = tolerateErrors; + frame = new Frame(in.getMaxFrameLength()); } public InputStream getInputStream() { diff --git a/components/net/sf/briar/transport/IncomingReliabilityLayerImpl.java b/components/net/sf/briar/transport/IncomingReliabilityLayerImpl.java index 5635d0313cac6ee32c25f209a8872c04f98b6013..174dee5baab01e144409d15dcda5a08dfa3efda2 100644 --- a/components/net/sf/briar/transport/IncomingReliabilityLayerImpl.java +++ b/components/net/sf/briar/transport/IncomingReliabilityLayerImpl.java @@ -5,12 +5,13 @@ import java.util.ArrayList; import java.util.LinkedList; import java.util.ListIterator; +/** A reliability layer that reorders out-of-order frames. */ class IncomingReliabilityLayerImpl implements IncomingReliabilityLayer { private final IncomingAuthenticationLayer in; private final int maxFrameLength; private final FrameWindow window; - private final LinkedList<Frame> frames; + private final LinkedList<Frame> frames; // Ordered by frame number private final ArrayList<Frame> freeFrames; private long nextFrameNumber = 0L; @@ -38,6 +39,8 @@ class IncomingReliabilityLayerImpl implements IncomingReliabilityLayer { // If the frame is in order, return it long frameNumber = f.getFrameNumber(); if(frameNumber == nextFrameNumber) { + if(!window.remove(nextFrameNumber)) + throw new IllegalStateException(); nextFrameNumber++; return f; } @@ -60,8 +63,8 @@ class IncomingReliabilityLayerImpl implements IncomingReliabilityLayer { } next = frames.peek(); } - assert next != null && next.getFrameNumber() == nextFrameNumber; frames.poll(); + if(!window.remove(nextFrameNumber)) throw new IllegalStateException(); nextFrameNumber++; return next; } @@ -69,4 +72,9 @@ class IncomingReliabilityLayerImpl implements IncomingReliabilityLayer { public int getMaxFrameLength() { return maxFrameLength; } + + // Only for testing + public int getFreeFramesCount() { + return freeFrames.size(); + } } diff --git a/test/build.xml b/test/build.xml index 8d73c651a6a796bc76d2bb482fe776533b5fdf7c..77541769ffe5425e41d192557f174ee53c8f9809 100644 --- a/test/build.xml +++ b/test/build.xml @@ -59,6 +59,7 @@ <test name='net.sf.briar.transport.FrameWindowImplTest'/> <test name='net.sf.briar.transport.IncomingEncryptionLayerImplTest'/> <test name='net.sf.briar.transport.IncomingErrorCorrectionLayerImplTest'/> + <test name='net.sf.briar.transport.IncomingReliabilityLayerImplTest'/> <test name='net.sf.briar.transport.OutgoingEncryptionLayerImplTest'/> <test name='net.sf.briar.transport.SegmentedIncomingEncryptionLayerTest'/> <test name='net.sf.briar.transport.SegmentedOutgoingEncryptionLayerTest'/> diff --git a/test/net/sf/briar/transport/IncomingReliabilityLayerImplTest.java b/test/net/sf/briar/transport/IncomingReliabilityLayerImplTest.java new file mode 100644 index 0000000000000000000000000000000000000000..2a2c11faf3eff4537af6929efeb44f6c74868aa7 --- /dev/null +++ b/test/net/sf/briar/transport/IncomingReliabilityLayerImplTest.java @@ -0,0 +1,93 @@ +package net.sf.briar.transport; + +import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH; +import static net.sf.briar.api.transport.TransportConstants.FRAME_WINDOW_SIZE; +import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH; + +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import net.sf.briar.BriarTestCase; +import net.sf.briar.api.transport.ConnectionReader; + +import org.junit.Test; + +public class IncomingReliabilityLayerImplTest extends BriarTestCase { + + @Test + public void testNoReordering() throws Exception { + List<Integer> frameNumbers = new ArrayList<Integer>(); + // Receive FRAME_WINDOW_SIZE * 2 frames in the correct order + for(int i = 0; i < FRAME_WINDOW_SIZE * 2; i++) frameNumbers.add(i); + IncomingAuthenticationLayer authentication = + new TestIncomingAuthenticationLayer(frameNumbers); + IncomingReliabilityLayerImpl reliability = + new IncomingReliabilityLayerImpl(authentication); + ConnectionReader reader = new ConnectionReaderImpl(reliability, false); + InputStream in = reader.getInputStream(); + for(int i = 0; i < FRAME_WINDOW_SIZE * 2; i++) { + for(int j = 0; j < 100; j++) assertEquals(i, in.read()); + } + assertEquals(-1, in.read()); + // No free frames should be cached + assertEquals(0, reliability.getFreeFramesCount()); + } + + @Test + public void testReordering() throws Exception { + List<Integer> frameNumbers = new ArrayList<Integer>(); + // Receive the first FRAME_WINDOW_SIZE frames in a random order + for(int i = 0; i < FRAME_WINDOW_SIZE; i++) frameNumbers.add(i); + Collections.shuffle(frameNumbers); + // Receive the next FRAME_WINDOW_SIZE frames in the correct order + for(int i = FRAME_WINDOW_SIZE; i < FRAME_WINDOW_SIZE * 2; i++) { + frameNumbers.add(i); + } + // The reliability layer should reorder the frames + IncomingAuthenticationLayer authentication = + new TestIncomingAuthenticationLayer(frameNumbers); + IncomingReliabilityLayerImpl reliability = + new IncomingReliabilityLayerImpl(authentication); + ConnectionReader reader = new ConnectionReaderImpl(reliability, false); + InputStream in = reader.getInputStream(); + for(int i = 0; i < FRAME_WINDOW_SIZE * 2; i++) { + for(int j = 0; j < 100; j++) assertEquals(i, in.read()); + } + assertEquals(-1, in.read()); + // Fewer than FRAME_WINDOW_SIZE free frames should be cached + assertTrue(reliability.getFreeFramesCount() < 32); + } + + private static class TestIncomingAuthenticationLayer + implements IncomingAuthenticationLayer { + + private final List<Integer> frameNumbers; + + private int index; + + private TestIncomingAuthenticationLayer(List<Integer> frameNumbers) { + this.frameNumbers = frameNumbers; + index = 0; + } + + public boolean readFrame(Frame f, FrameWindow window) { + if(index >= frameNumbers.size()) return false; + int frameNumber = frameNumbers.get(index); + assertTrue(window.contains(frameNumber)); + index++; + byte[] buf = f.getBuffer(); + HeaderEncoder.encodeHeader(buf, frameNumber, 100, 0); + for(int i = 0; i < 100; i++) { + buf[FRAME_HEADER_LENGTH + i] = (byte) frameNumber; + } + f.setLength(FRAME_HEADER_LENGTH + 100 + MAC_LENGTH); + return true; + } + + public int getMaxFrameLength() { + return FRAME_HEADER_LENGTH + 100 + MAC_LENGTH; + } + } +}