diff --git a/test/net/sf/briar/transport/ConnectionReaderImplTest.java b/test/net/sf/briar/transport/ConnectionReaderImplTest.java index 1e8fe8a5681b39b142731de4508ea6f0b370dcf8..b0ce60435921e5231aef8121741c63d502e8320d 100644 --- a/test/net/sf/briar/transport/ConnectionReaderImplTest.java +++ b/test/net/sf/briar/transport/ConnectionReaderImplTest.java @@ -1,5 +1,7 @@ package net.sf.briar.transport; +import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH; + import java.io.ByteArrayInputStream; import java.io.EOFException; import java.io.IOException; @@ -10,6 +12,7 @@ import javax.crypto.Mac; import junit.framework.TestCase; import net.sf.briar.TestUtils; +import net.sf.briar.api.FormatException; import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.crypto.CryptoModule; @@ -33,13 +36,28 @@ public class ConnectionReaderImplTest extends TestCase { mac.init(crypto.generateSecretKey()); } - // FIXME: Test corner cases and corrupt frames + @Test + public void testLengthZero() throws Exception { + // Six bytes for the header, none for the payload + byte[] frame = new byte[6 + mac.getMacLength()]; + // Calculate the MAC + mac.update(frame, 0, 6); + mac.doFinal(frame, 6); + // Read the frame + ByteArrayInputStream in = new ByteArrayInputStream(frame); + ConnectionDecrypter d = new NullConnectionDecrypter(in); + ConnectionReader r = new ConnectionReaderImpl(d, mac); + try { + r.getInputStream().read(); + fail(); + } catch(FormatException expected) {} + } @Test - public void testSingleByteFrame() throws Exception { + public void testLengthOne() throws Exception { // Six bytes for the header, one for the payload byte[] frame = new byte[6 + 1 + mac.getMacLength()]; - ByteUtils.writeUint16(1, frame, 4); // Payload length = 1 + ByteUtils.writeUint16(1, frame, 4); // Frame number 0, length 1 // Calculate the MAC mac.update(frame, 0, 6 + 1); mac.doFinal(frame, 6 + 1); @@ -52,17 +70,49 @@ public class ConnectionReaderImplTest extends TestCase { assertEquals(-1, r.getInputStream().read()); } + @Test + public void testMaxLength() throws Exception { + int maxPayloadLength = MAX_FRAME_LENGTH - 6 - mac.getMacLength(); + // First frame: max payload length + byte[] frame = new byte[6 + maxPayloadLength + mac.getMacLength()]; + ByteUtils.writeUint16(maxPayloadLength, frame, 4); + mac.update(frame, 0, 6 + maxPayloadLength); + mac.doFinal(frame, 6 + maxPayloadLength); + // Second frame: max payload length plus one + byte[] frame1 = new byte[6 + maxPayloadLength + 1 + mac.getMacLength()]; + ByteUtils.writeUint32(1, frame1, 0); + ByteUtils.writeUint16(maxPayloadLength + 1, frame1, 4); + mac.update(frame1, 0, 6 + maxPayloadLength + 1); + mac.doFinal(frame1, 6 + maxPayloadLength + 1); + // Concatenate the frames + ByteArrayOutputStream out = new ByteArrayOutputStream(); + out.write(frame); + out.write(frame1); + // Read the first frame + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + ConnectionDecrypter d = new NullConnectionDecrypter(in); + ConnectionReader r = new ConnectionReaderImpl(d, mac); + byte[] read = new byte[maxPayloadLength]; + TestUtils.readFully(r.getInputStream(), read); + // Try to read the second frame + byte[] read1 = new byte[maxPayloadLength + 1]; + try { + TestUtils.readFully(r.getInputStream(), read1); + fail(); + } catch(FormatException expected) {} + } + @Test public void testMultipleFrames() throws Exception { // First frame: 123-byte payload byte[] frame = new byte[6 + 123 + mac.getMacLength()]; - ByteUtils.writeUint16(123, frame, 4); + ByteUtils.writeUint16(123, frame, 4); // Frame number 0, length 123 mac.update(frame, 0, 6 + 123); mac.doFinal(frame, 6 + 123); // Second frame: 1234-byte payload byte[] frame1 = new byte[6 + 1234 + mac.getMacLength()]; - ByteUtils.writeUint32(1, frame1, 0); - ByteUtils.writeUint16(1234, frame1, 4); + ByteUtils.writeUint32(1, frame1, 0); // Frame number 1 + ByteUtils.writeUint16(1234, frame1, 4); // Length 1234 mac.update(frame1, 0, 6 + 1234); mac.doFinal(frame1, 6 + 1234); // Concatenate the frames @@ -81,6 +131,46 @@ public class ConnectionReaderImplTest extends TestCase { assertTrue(Arrays.equals(new byte[1234], read1)); } + @Test + public void testCorruptPayload() throws Exception { + // Six bytes for the header, eight for the payload + byte[] frame = new byte[6 + 8 + mac.getMacLength()]; + ByteUtils.writeUint16(8, frame, 4); // Frame number 0, length 8 + // Calculate the MAC + mac.update(frame, 0, 6 + 8); + mac.doFinal(frame, 6 + 8); + // Modify the payload + frame[12] ^= 1; + // Try to read the frame - not a single byte should be read + ByteArrayInputStream in = new ByteArrayInputStream(frame); + ConnectionDecrypter d = new NullConnectionDecrypter(in); + ConnectionReader r = new ConnectionReaderImpl(d, mac); + try { + r.getInputStream().read(); + fail(); + } catch(FormatException expected) {} + } + + @Test + public void testCorruptMac() throws Exception { + // Six bytes for the header, eight for the payload + byte[] frame = new byte[6 + 8 + mac.getMacLength()]; + ByteUtils.writeUint16(8, frame, 4); // Frame number 0, length 8 + // Calculate the MAC + mac.update(frame, 0, 6 + 8); + mac.doFinal(frame, 6 + 8); + // Modify the MAC + frame[17] ^= 1; + // Try to read the frame - not a single byte should be read + ByteArrayInputStream in = new ByteArrayInputStream(frame); + ConnectionDecrypter d = new NullConnectionDecrypter(in); + ConnectionReader r = new ConnectionReaderImpl(d, mac); + try { + r.getInputStream().read(); + fail(); + } catch(FormatException expected) {} + } + /** A ConnectionDecrypter that performs no decryption. */ private static class NullConnectionDecrypter implements ConnectionDecrypter {