diff --git a/components/net/sf/briar/transport/ConnectionReaderImpl.java b/components/net/sf/briar/transport/ConnectionReaderImpl.java index 2c07351c4a2c175459f3e511a81e4c55c5aeaed2..258eb824ee7384b9248fe569405d4fef6ff1c4fd 100644 --- a/components/net/sf/briar/transport/ConnectionReaderImpl.java +++ b/components/net/sf/briar/transport/ConnectionReaderImpl.java @@ -54,7 +54,7 @@ implements ConnectionReader { @Override public int read() throws IOException { - if(betweenFrames && !readFrame()) return -1; + if(betweenFrames && !readNonEmptyFrame()) return -1; int i = payload[payloadOff]; payloadOff++; payloadLen--; @@ -69,7 +69,7 @@ implements ConnectionReader { @Override public int read(byte[] b, int off, int len) throws IOException { - if(betweenFrames && !readFrame()) return -1; + if(betweenFrames && !readNonEmptyFrame()) return -1; len = Math.min(len, payloadLen); System.arraycopy(payload, payloadOff, b, off, len); payloadOff += len; @@ -78,7 +78,15 @@ implements ConnectionReader { return len; } - private boolean readFrame() throws IOException { + private boolean readNonEmptyFrame() throws IOException { + int payload = 0; + do { + payload = readFrame(); + } while(payload == 0); + return payload > 0; + } + + private int readFrame() throws IOException { assert betweenFrames; // Don't allow more than 2^32 frames to be read if(frame > MAX_32_BIT_UNSIGNED) throw new IllegalStateException(); @@ -89,7 +97,7 @@ implements ConnectionReader { if(read == -1) break; offset += read; } - if(offset == 0) return false; // EOF between frames + if(offset == 0) return -1; // EOF between frames if(offset < header.length) throw new EOFException(); // Unexpected EOF // Check that the frame number is correct and the length is legal if(!HeaderEncoder.validateHeader(header, frame, maxPayloadLength)) @@ -122,8 +130,8 @@ implements ConnectionReader { byte[] expectedMac = mac.doFinal(); decrypter.readMac(footer); if(!Arrays.equals(expectedMac, footer)) throw new FormatException(); - betweenFrames = false; frame++; - return true; + if(payloadLen > 0) betweenFrames = false; + return payloadLen; } } diff --git a/components/net/sf/briar/transport/HeaderEncoder.java b/components/net/sf/briar/transport/HeaderEncoder.java index d8a3868b6cae317ef81bc41b9b47c547a82d31df..58bf41da1e3f454aa03bbbc6461c6ce49593ec23 100644 --- a/components/net/sf/briar/transport/HeaderEncoder.java +++ b/components/net/sf/briar/transport/HeaderEncoder.java @@ -1,13 +1,13 @@ package net.sf.briar.transport; -import net.sf.briar.api.transport.TransportConstants; +import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH; import net.sf.briar.util.ByteUtils; class HeaderEncoder { static void encodeHeader(byte[] header, long frame, int payload, int padding) { - if(header.length < TransportConstants.FRAME_HEADER_LENGTH) + if(header.length < FRAME_HEADER_LENGTH) throw new IllegalArgumentException(); if(frame < 0 || frame > ByteUtils.MAX_32_BIT_UNSIGNED) throw new IllegalArgumentException(); @@ -21,24 +21,22 @@ class HeaderEncoder { } static boolean validateHeader(byte[] header, long frame, int max) { - if(header.length < TransportConstants.FRAME_HEADER_LENGTH) - return false; + if(header.length < FRAME_HEADER_LENGTH) return false; if(ByteUtils.readUint32(header, 0) != frame) return false; int payload = ByteUtils.readUint16(header, 4); int padding = ByteUtils.readUint16(header, 6); - if(payload + padding == 0) return false; if(payload + padding > max) return false; return true; } static int getPayloadLength(byte[] header) { - if(header.length < TransportConstants.FRAME_HEADER_LENGTH) + if(header.length < FRAME_HEADER_LENGTH) throw new IllegalArgumentException(); return ByteUtils.readUint16(header, 4); } static int getPaddingLength(byte[] header) { - if(header.length < TransportConstants.FRAME_HEADER_LENGTH) + if(header.length < FRAME_HEADER_LENGTH) throw new IllegalArgumentException(); return ByteUtils.readUint16(header, 6); } diff --git a/test/net/sf/briar/transport/ConnectionReaderImplTest.java b/test/net/sf/briar/transport/ConnectionReaderImplTest.java index c5a87df8efaffc4fac1a55654d53754c716e1435..906958f8f1fbd3ef61f4571408f54927f3457ce0 100644 --- a/test/net/sf/briar/transport/ConnectionReaderImplTest.java +++ b/test/net/sf/briar/transport/ConnectionReaderImplTest.java @@ -33,10 +33,8 @@ public class ConnectionReaderImplTest extends TransportTest { ByteArrayInputStream in = new ByteArrayInputStream(frame); ConnectionDecrypter d = new NullConnectionDecrypter(in); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); - try { - r.getInputStream().read(); - fail(); - } catch(FormatException expected) {} + // There should be no bytes available before EOF + assertEquals(-1, r.getInputStream().read()); } @Test