diff --git a/components/net/sf/briar/transport/ConnectionDecrypter.java b/components/net/sf/briar/transport/ConnectionDecrypter.java index b4763679b63427e08bf1b35ac5ac60a17025a8ba..e872f242fecccfdb61c60e8f3063f60ec6d88e40 100644 --- a/components/net/sf/briar/transport/ConnectionDecrypter.java +++ b/components/net/sf/briar/transport/ConnectionDecrypter.java @@ -1,14 +1,13 @@ package net.sf.briar.transport; import java.io.IOException; -import java.io.InputStream; /** Decrypts unauthenticated data received over a connection. */ interface ConnectionDecrypter { - /** Returns an input stream from which decrypted data can be read. */ - InputStream getInputStream(); - - /** Reads and decrypts the remainder of the current frame. */ - void readFinal(byte[] b) throws IOException; + /** + * Reads and decrypts a frame into the given buffer and returns the length + * of the decrypted frame, or -1 if no more frames can be read. + */ + int readFrame(byte[] b) throws IOException; } diff --git a/components/net/sf/briar/transport/ConnectionDecrypterImpl.java b/components/net/sf/briar/transport/ConnectionDecrypterImpl.java index ff354a57b6c36a1d216a523315049452838b046d..b4d93232cefa587b673ad7302a2fd76057b53c2d 100644 --- a/components/net/sf/briar/transport/ConnectionDecrypterImpl.java +++ b/components/net/sf/briar/transport/ConnectionDecrypterImpl.java @@ -1,9 +1,10 @@ package net.sf.briar.transport; +import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH; +import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH; import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED; import java.io.EOFException; -import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; import java.security.GeneralSecurityException; @@ -11,144 +12,88 @@ import java.security.GeneralSecurityException; import javax.crypto.Cipher; import javax.crypto.spec.IvParameterSpec; +import net.sf.briar.api.FormatException; import net.sf.briar.api.crypto.ErasableKey; -class ConnectionDecrypterImpl extends FilterInputStream -implements ConnectionDecrypter { +class ConnectionDecrypterImpl implements ConnectionDecrypter { + private final InputStream in; private final Cipher frameCipher; private final ErasableKey frameKey; - private final byte[] iv, buf; + private final int macLength, blockSize; + private final byte[] iv; - private int bufOff = 0, bufLen = 0; private long frame = 0L; - private boolean betweenFrames = true; ConnectionDecrypterImpl(InputStream in, Cipher frameCipher, - ErasableKey frameKey) { - super(in); + ErasableKey frameKey, int macLength) { + this.in = in; this.frameCipher = frameCipher; this.frameKey = frameKey; - iv = IvEncoder.encodeIv(0, frameCipher.getBlockSize()); - buf = new byte[frameCipher.getBlockSize()]; + this.macLength = macLength; + blockSize = frameCipher.getBlockSize(); + if(blockSize < FRAME_HEADER_LENGTH) + throw new IllegalArgumentException(); + iv = IvEncoder.encodeIv(0, blockSize); } - public InputStream getInputStream() { - return this; - } - - public void readFinal(byte[] b) throws IOException { + public int readFrame(byte[] b) throws IOException { + if(b.length < MAX_FRAME_LENGTH) throw new IllegalArgumentException(); + if(frame > MAX_32_BIT_UNSIGNED) throw new IllegalStateException(); + // Initialise the cipher + IvEncoder.updateIv(iv, frame); + IvParameterSpec ivSpec = new IvParameterSpec(iv); try { - if(betweenFrames) throw new IllegalStateException(); - // If we have any plaintext in the buffer, copy it into the frame - System.arraycopy(buf, bufOff, b, 0, bufLen); - // Read the remainder of the frame - int offset = bufLen; - while(offset < b.length) { - int read = in.read(b, offset, b.length - offset); - if(read == -1) break; + frameCipher.init(Cipher.DECRYPT_MODE, frameKey, ivSpec); + } catch(GeneralSecurityException badIvOrKey) { + throw new RuntimeException(badIvOrKey); + } + try { + // Read the first block + int offset = 0; + while(offset < blockSize) { + int read = in.read(b, offset, blockSize - offset); + if(read == -1) { + if(offset == 0) return -1; + if(offset < blockSize) throw new EOFException(); + break; + } offset += read; } - if(offset < b.length) throw new EOFException(); // Unexpected EOF - // Decrypt the remainder of the frame + // Decrypt the first block try { - int length = b.length - bufLen; - int i = frameCipher.doFinal(b, bufLen, length, b, bufLen); - if(i < length) throw new RuntimeException(); + int decrypted = frameCipher.update(b, 0, blockSize, b); + assert decrypted == blockSize; } catch(GeneralSecurityException badCipher) { throw new RuntimeException(badCipher); } - bufOff = bufLen = 0; - betweenFrames = true; - } catch(IOException e) { - frameKey.erase(); - throw e; - } - } - - @Override - public int read() throws IOException { - try { - if(betweenFrames) initialiseCipher(); - if(bufLen == 0) { - if(!readBlock()) { - frameKey.erase(); - return -1; - } - bufOff = 0; - bufLen = buf.length; + // Validate and parse the header + int max = MAX_FRAME_LENGTH - FRAME_HEADER_LENGTH - macLength; + if(!HeaderEncoder.validateHeader(b, frame, max)) + throw new FormatException(); + int payload = HeaderEncoder.getPayloadLength(b); + int padding = HeaderEncoder.getPaddingLength(b); + int length = FRAME_HEADER_LENGTH + payload + padding + macLength; + if(length > MAX_FRAME_LENGTH) throw new FormatException(); + // Read the remainder of the frame + while(offset < length) { + int read = in.read(b, offset, length - offset); + if(read == -1) throw new EOFException(); + offset += read; } - int i = buf[bufOff]; - bufOff++; - bufLen--; - return i < 0 ? i + 256 : i; - } catch(IOException e) { - frameKey.erase(); - throw e; - } - } - - @Override - public int read(byte[] b) throws IOException { - return read(b, 0, b.length); - } - - @Override - public int read(byte[] b, int off, int len) throws IOException { - try { - if(betweenFrames) initialiseCipher(); - if(bufLen == 0) { - if(!readBlock()) { - frameKey.erase(); - return -1; - } - bufOff = 0; - bufLen = buf.length; + // Decrypt the remainder of the frame + try { + int decrypted = frameCipher.doFinal(b, blockSize, + length - blockSize, b, blockSize); + assert decrypted == length - blockSize; + } catch(GeneralSecurityException badCipher) { + throw new RuntimeException(badCipher); } - int length = Math.min(len, bufLen); - System.arraycopy(buf, bufOff, b, off, length); - bufOff += length; - bufLen -= length; + frame++; return length; } catch(IOException e) { frameKey.erase(); throw e; } } - - // Although we're using CTR mode, which doesn't require full blocks of - // ciphertext, the cipher still tries to operate a block at a time - private boolean readBlock() throws IOException { - // Try to read a block of ciphertext - int offset = 0; - while(offset < buf.length) { - int read = in.read(buf, offset, buf.length - offset); - if(read == -1) break; - offset += read; - } - if(offset == 0) return false; - if(offset < buf.length) throw new EOFException(); // Unexpected EOF - // Decrypt the block - try { - int i = frameCipher.update(buf, 0, offset, buf); - if(i < offset) throw new RuntimeException(); - } catch(GeneralSecurityException badCipher) { - throw new RuntimeException(badCipher); - } - return true; - } - - private void initialiseCipher() { - assert betweenFrames; - if(frame > MAX_32_BIT_UNSIGNED) throw new IllegalStateException(); - IvEncoder.updateIv(iv, frame); - IvParameterSpec ivSpec = new IvParameterSpec(iv); - try { - frameCipher.init(Cipher.DECRYPT_MODE, frameKey, ivSpec); - } catch(GeneralSecurityException badIvOrKey) { - throw new RuntimeException(badIvOrKey); - } - frame++; - betweenFrames = false; - } } \ No newline at end of file diff --git a/components/net/sf/briar/transport/ConnectionEncrypter.java b/components/net/sf/briar/transport/ConnectionEncrypter.java index 7d4fdb94252d5ef8c8827e432b42410ece355488..11cf16e7c54c413448ca854ad7c1547a1cbc1141 100644 --- a/components/net/sf/briar/transport/ConnectionEncrypter.java +++ b/components/net/sf/briar/transport/ConnectionEncrypter.java @@ -6,7 +6,7 @@ import java.io.IOException; interface ConnectionEncrypter { /** Encrypts and writes the given frame. */ - void writeFrame(byte[] b, int off, int len) throws IOException; + void writeFrame(byte[] b, int len) throws IOException; /** Flushes the output stream. */ void flush() throws IOException; diff --git a/components/net/sf/briar/transport/ConnectionEncrypterImpl.java b/components/net/sf/briar/transport/ConnectionEncrypterImpl.java index 24c66e35159e855011507e5ed12d2a4073776e63..e7e161652414c6d76b2c5cddf4ed159bc569bc68 100644 --- a/components/net/sf/briar/transport/ConnectionEncrypterImpl.java +++ b/components/net/sf/briar/transport/ConnectionEncrypterImpl.java @@ -35,7 +35,7 @@ class ConnectionEncrypterImpl implements ConnectionEncrypter { if(tag.length != TAG_LENGTH) throw new IllegalArgumentException(); } - public void writeFrame(byte[] b, int off, int len) throws IOException { + public void writeFrame(byte[] b, int len) throws IOException { try { if(!tagWritten) { out.write(tag); @@ -47,12 +47,12 @@ class ConnectionEncrypterImpl implements ConnectionEncrypter { IvParameterSpec ivSpec = new IvParameterSpec(iv); try { frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec); - int encrypted = frameCipher.doFinal(b, off, len, b, off); + int encrypted = frameCipher.doFinal(b, 0, len, b, 0); assert encrypted == len; } catch(GeneralSecurityException badCipher) { throw new RuntimeException(badCipher); } - out.write(b, off, len); + out.write(b, 0, len); capacity -= len; frame++; } catch(IOException e) { diff --git a/components/net/sf/briar/transport/ConnectionReaderFactoryImpl.java b/components/net/sf/briar/transport/ConnectionReaderFactoryImpl.java index d73a45931cbcb32bb239174f2794d3014a850182..a80769c984d9a7ad8edd78f20b3815fa4b6fb2cd 100644 --- a/components/net/sf/briar/transport/ConnectionReaderFactoryImpl.java +++ b/components/net/sf/briar/transport/ConnectionReaderFactoryImpl.java @@ -46,10 +46,10 @@ class ConnectionReaderFactoryImpl implements ConnectionReaderFactory { ByteUtils.erase(secret); // Create the decrypter Cipher frameCipher = crypto.getFrameCipher(); + Mac mac = crypto.getMac(); ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in, - frameCipher, frameKey); + frameCipher, frameKey, mac.getMacLength()); // Create the reader - Mac mac = crypto.getMac(); return new ConnectionReaderImpl(decrypter, mac, macKey); } } diff --git a/components/net/sf/briar/transport/ConnectionReaderImpl.java b/components/net/sf/briar/transport/ConnectionReaderImpl.java index 880119052d23d6a649d692ece22f9a58cf0d9b4c..4aef95b3fc909c6922e8bb94d919a39360b023a2 100644 --- a/components/net/sf/briar/transport/ConnectionReaderImpl.java +++ b/components/net/sf/briar/transport/ConnectionReaderImpl.java @@ -4,12 +4,9 @@ import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH; import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH; import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED; -import java.io.EOFException; -import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; import java.security.InvalidKeyException; -import java.util.Arrays; import javax.crypto.Mac; @@ -17,21 +14,18 @@ import net.sf.briar.api.FormatException; import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.api.transport.ConnectionReader; -class ConnectionReaderImpl extends FilterInputStream -implements ConnectionReader { +class ConnectionReaderImpl extends InputStream implements ConnectionReader { private final ConnectionDecrypter decrypter; private final Mac mac; - private final int maxPayloadLength; - private final byte[] header, payload, footer; + private final int macLength; + private final byte[] buf; private long frame = 0L; - private int payloadOff = 0, payloadLen = 0; - private boolean betweenFrames = true; + private int bufOffset = 0, bufLength = 0; ConnectionReaderImpl(ConnectionDecrypter decrypter, Mac mac, ErasableKey macKey) { - super(decrypter.getInputStream()); this.decrypter = decrypter; this.mac = mac; // Initialise the MAC @@ -41,11 +35,8 @@ implements ConnectionReader { throw new IllegalArgumentException(e); } macKey.erase(); - maxPayloadLength = - MAX_FRAME_LENGTH - FRAME_HEADER_LENGTH - mac.getMacLength(); - header = new byte[FRAME_HEADER_LENGTH]; - payload = new byte[maxPayloadLength]; - footer = new byte[mac.getMacLength()]; + macLength = mac.getMacLength(); + buf = new byte[MAX_FRAME_LENGTH]; } public InputStream getInputStream() { @@ -54,12 +45,11 @@ implements ConnectionReader { @Override public int read() throws IOException { - if(betweenFrames && !readNonEmptyFrame()) return -1; - int i = payload[payloadOff]; - payloadOff++; - payloadLen--; - if(payloadLen == 0) betweenFrames = true; - return i; + while(bufLength == 0) if(!readFrame()) return -1; + int b = buf[bufOffset] & 0xff; + bufOffset++; + bufLength--; + return b; } @Override @@ -69,69 +59,44 @@ implements ConnectionReader { @Override public int read(byte[] b, int off, int len) throws IOException { - if(betweenFrames && !readNonEmptyFrame()) return -1; - len = Math.min(len, payloadLen); - System.arraycopy(payload, payloadOff, b, off, len); - payloadOff += len; - payloadLen -= len; - if(payloadLen == 0) betweenFrames = true; + while(bufLength == 0) if(!readFrame()) return -1; + len = Math.min(len, bufLength); + System.arraycopy(buf, bufOffset, b, off, len); + bufOffset += len; + bufLength -= len; return len; } - private boolean readNonEmptyFrame() throws IOException { - int payload = 0; - do { - payload = readFrame(); - } while(payload == 0); - return payload > 0; - } - - private int readFrame() throws IOException { - assert betweenFrames; + private boolean readFrame() throws IOException { + assert bufLength == 0; // Don't allow more than 2^32 frames to be read if(frame > MAX_32_BIT_UNSIGNED) throw new IllegalStateException(); - // Read the header - int offset = 0; - while(offset < header.length) { - int read = in.read(header, offset, header.length - offset); - if(read == -1) break; - offset += read; - } - if(offset == 0) return -1; // EOF between frames - if(offset < header.length) throw new EOFException(); // Unexpected EOF + // Read a frame + int length = decrypter.readFrame(buf); + if(length == -1) return false; // Check that the frame number is correct and the length is legal - if(!HeaderEncoder.validateHeader(header, frame, maxPayloadLength)) + int max = MAX_FRAME_LENGTH - FRAME_HEADER_LENGTH - macLength; + if(!HeaderEncoder.validateHeader(buf, frame, max)) + throw new FormatException(); + int payload = HeaderEncoder.getPayloadLength(buf); + int padding = HeaderEncoder.getPaddingLength(buf); + if(length != FRAME_HEADER_LENGTH + payload + padding + macLength) throw new FormatException(); - payloadLen = HeaderEncoder.getPayloadLength(header); - int paddingLen = HeaderEncoder.getPaddingLength(header); - mac.update(header); - // Read the payload - offset = 0; - while(offset < payloadLen) { - int read = in.read(payload, offset, payloadLen - offset); - if(read == -1) throw new EOFException(); // Unexpected EOF - mac.update(payload, offset, read); - offset += read; - } - payloadOff = 0; - // Read the padding - while(offset < payloadLen + paddingLen) { - int read = in.read(payload, offset, - payloadLen + paddingLen - offset); - if(read == -1) throw new EOFException(); // Unexpected EOF - mac.update(payload, offset, read); - offset += read; - } // Check that the padding is all zeroes - for(int i = payloadLen; i < payloadLen + paddingLen; i++) { - if(payload[i] != 0) throw new FormatException(); + int paddingStart = FRAME_HEADER_LENGTH + payload; + for(int i = paddingStart; i < paddingStart + padding; i++) { + if(buf[i] != 0) throw new FormatException(); } - // Read the MAC + // Check the MAC + int macStart = FRAME_HEADER_LENGTH + payload + padding; + mac.update(buf, 0, macStart); byte[] expectedMac = mac.doFinal(); - decrypter.readFinal(footer); - if(!Arrays.equals(expectedMac, footer)) throw new FormatException(); + for(int i = 0; i < macLength; i++) { + if(expectedMac[i] != buf[macStart + i]) throw new FormatException(); + } + bufOffset = FRAME_HEADER_LENGTH; + bufLength = payload; frame++; - if(payloadLen > 0) betweenFrames = false; - return payloadLen; + return true; } } diff --git a/components/net/sf/briar/transport/ConnectionWriterImpl.java b/components/net/sf/briar/transport/ConnectionWriterImpl.java index eb617a2e6665c493f21a2ebc10d4d296e04fcda7..fd823dd0fd362b61baddf112d98cc4c87a720914 100644 --- a/components/net/sf/briar/transport/ConnectionWriterImpl.java +++ b/components/net/sf/briar/transport/ConnectionWriterImpl.java @@ -101,7 +101,7 @@ class ConnectionWriterImpl extends OutputStream implements ConnectionWriter { } catch(ShortBufferException badMac) { throw new RuntimeException(badMac); } - encrypter.writeFrame(buf, 0, bufLength + mac.getMacLength()); + encrypter.writeFrame(buf, bufLength + mac.getMacLength()); bufLength = FRAME_HEADER_LENGTH; frame++; } diff --git a/test/net/sf/briar/transport/ConnectionDecrypterImplTest.java b/test/net/sf/briar/transport/ConnectionDecrypterImplTest.java index 524fd37146c12f66a5100c5667fb1358ae26d288..e02447267a6d46a23ccad940d6eb7f051bbaf29a 100644 --- a/test/net/sf/briar/transport/ConnectionDecrypterImplTest.java +++ b/test/net/sf/briar/transport/ConnectionDecrypterImplTest.java @@ -1,6 +1,7 @@ package net.sf.briar.transport; -import static org.junit.Assert.assertArrayEquals; +import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH; +import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH; import java.io.ByteArrayInputStream; @@ -8,7 +9,6 @@ import javax.crypto.Cipher; import javax.crypto.spec.IvParameterSpec; import net.sf.briar.BriarTestCase; -import net.sf.briar.TestUtils; import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.ErasableKey; import net.sf.briar.crypto.CryptoModule; @@ -45,58 +45,40 @@ public class ConnectionDecrypterImplTest extends BriarTestCase { } private void testDecryption(boolean initiator) throws Exception { - // Calculate the expected plaintext for the first frame - byte[] iv = new byte[frameCipher.getBlockSize()]; - byte[] ciphertext = new byte[123]; - byte[] ciphertextMac = new byte[MAC_LENGTH]; + // Calculate the ciphertext for the first frame + byte[] plaintext = new byte[FRAME_HEADER_LENGTH + 123 + MAC_LENGTH]; + HeaderEncoder.encodeHeader(plaintext, 0L, 123, 0); + byte[] iv = IvEncoder.encodeIv(0L, frameCipher.getBlockSize()); IvParameterSpec ivSpec = new IvParameterSpec(iv); - frameCipher.init(Cipher.DECRYPT_MODE, frameKey, ivSpec); - byte[] plaintext = new byte[ciphertext.length + ciphertextMac.length]; - int offset = frameCipher.update(ciphertext, 0, ciphertext.length, - plaintext); - frameCipher.doFinal(ciphertextMac, 0, ciphertextMac.length, plaintext, - offset); - // Calculate the expected plaintext for the second frame - byte[] ciphertext1 = new byte[1234]; + frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec); + byte[] ciphertext = new byte[plaintext.length]; + frameCipher.doFinal(plaintext, 0, plaintext.length, ciphertext); + // Calculate the ciphertext for the second frame + byte[] plaintext1 = new byte[FRAME_HEADER_LENGTH + 1234 + MAC_LENGTH]; + HeaderEncoder.encodeHeader(plaintext1, 1L, 1234, 0); IvEncoder.updateIv(iv, 1L); ivSpec = new IvParameterSpec(iv); - frameCipher.init(Cipher.DECRYPT_MODE, frameKey, ivSpec); - byte[] plaintext1 = new byte[ciphertext1.length + ciphertextMac.length]; - offset = frameCipher.update(ciphertext1, 0, ciphertext1.length, - plaintext1); - frameCipher.doFinal(ciphertextMac, 0, ciphertextMac.length, plaintext1, - offset); + frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec); + byte[] ciphertext1 = new byte[plaintext1.length]; + frameCipher.doFinal(plaintext1, 0, plaintext1.length, ciphertext1); // Concatenate the ciphertexts ByteArrayOutputStream out = new ByteArrayOutputStream(); out.write(ciphertext); - out.write(ciphertextMac); out.write(ciphertext1); - out.write(ciphertextMac); ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); // Use a ConnectionDecrypter to decrypt the ciphertext ConnectionDecrypter d = new ConnectionDecrypterImpl(in, frameCipher, - frameKey); + frameKey, MAC_LENGTH); // First frame - byte[] decrypted = new byte[ciphertext.length]; - TestUtils.readFully(d.getInputStream(), decrypted); - byte[] decryptedMac = new byte[MAC_LENGTH]; - d.readFinal(decryptedMac); + byte[] decrypted = new byte[MAX_FRAME_LENGTH]; + assertEquals(plaintext.length, d.readFrame(decrypted)); + for(int i = 0; i < plaintext.length; i++) { + assertEquals(plaintext[i], decrypted[i]); + } // Second frame - byte[] decrypted1 = new byte[ciphertext1.length]; - TestUtils.readFully(d.getInputStream(), decrypted1); - byte[] decryptedMac1 = new byte[MAC_LENGTH]; - d.readFinal(decryptedMac1); - // Check that the actual plaintext matches the expected plaintext - out.reset(); - out.write(plaintext); - out.write(plaintext1); - byte[] expected = out.toByteArray(); - out.reset(); - out.write(decrypted); - out.write(decryptedMac); - out.write(decrypted1); - out.write(decryptedMac1); - byte[] actual = out.toByteArray(); - assertArrayEquals(expected, actual); + assertEquals(plaintext1.length, d.readFrame(decrypted)); + for(int i = 0; i < plaintext1.length; i++) { + assertEquals(plaintext1[i], decrypted[i]); + } } } diff --git a/test/net/sf/briar/transport/ConnectionEncrypterImplTest.java b/test/net/sf/briar/transport/ConnectionEncrypterImplTest.java index 7bb2996cfaef95b13d0dc3e3993183668b345dd4..9b25a30f217a1955564877fa4dbda187b5d005bd 100644 --- a/test/net/sf/briar/transport/ConnectionEncrypterImplTest.java +++ b/test/net/sf/briar/transport/ConnectionEncrypterImplTest.java @@ -69,8 +69,8 @@ public class ConnectionEncrypterImplTest extends BriarTestCase { out.reset(); ConnectionEncrypter e = new ConnectionEncrypterImpl(out, Long.MAX_VALUE, tagCipher, frameCipher, tagKey, frameKey); - e.writeFrame(plaintext, 0, plaintext.length); - e.writeFrame(plaintext1, 0, plaintext1.length); + e.writeFrame(plaintext, plaintext.length); + e.writeFrame(plaintext1, plaintext1.length); byte[] actual = out.toByteArray(); // Check that the actual ciphertext matches the expected ciphertext assertArrayEquals(expected, actual); diff --git a/test/net/sf/briar/transport/ConnectionReaderImplTest.java b/test/net/sf/briar/transport/ConnectionReaderImplTest.java index 56caddb96938971297f7fa22644646a0ce2a6b2d..0711f4f21727851ec74c514f511828df209097db 100644 --- a/test/net/sf/briar/transport/ConnectionReaderImplTest.java +++ b/test/net/sf/briar/transport/ConnectionReaderImplTest.java @@ -31,7 +31,7 @@ public class ConnectionReaderImplTest extends TransportTest { mac.doFinal(frame, FRAME_HEADER_LENGTH + payloadLength); // Read the frame ByteArrayInputStream in = new ByteArrayInputStream(frame); - ConnectionDecrypter d = new NullConnectionDecrypter(in); + ConnectionDecrypter d = new NullConnectionDecrypter(in, macLength); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); // There should be no bytes available before EOF assertEquals(-1, r.getInputStream().read()); @@ -49,7 +49,7 @@ public class ConnectionReaderImplTest extends TransportTest { mac.doFinal(frame, FRAME_HEADER_LENGTH + payloadLength); // Read the frame ByteArrayInputStream in = new ByteArrayInputStream(frame); - ConnectionDecrypter d = new NullConnectionDecrypter(in); + ConnectionDecrypter d = new NullConnectionDecrypter(in, macLength); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); // There should be one byte available before EOF assertEquals(0, r.getInputStream().read()); @@ -75,7 +75,7 @@ public class ConnectionReaderImplTest extends TransportTest { out.write(frame1); // Read the first frame ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); - ConnectionDecrypter d = new NullConnectionDecrypter(in); + ConnectionDecrypter d = new NullConnectionDecrypter(in, macLength); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); byte[] read = new byte[maxPayloadLength]; TestUtils.readFully(r.getInputStream(), read); @@ -109,7 +109,7 @@ public class ConnectionReaderImplTest extends TransportTest { out.write(frame1); // Read the first frame ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); - ConnectionDecrypter d = new NullConnectionDecrypter(in); + ConnectionDecrypter d = new NullConnectionDecrypter(in, macLength); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); byte[] read = new byte[maxPayloadLength - paddingLength]; TestUtils.readFully(r.getInputStream(), read); @@ -135,7 +135,7 @@ public class ConnectionReaderImplTest extends TransportTest { mac.doFinal(frame, FRAME_HEADER_LENGTH + payloadLength + paddingLength); // Read the frame ByteArrayInputStream in = new ByteArrayInputStream(frame); - ConnectionDecrypter d = new NullConnectionDecrypter(in); + ConnectionDecrypter d = new NullConnectionDecrypter(in, macLength); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); // The non-zero padding should be rejected try { @@ -167,7 +167,7 @@ public class ConnectionReaderImplTest extends TransportTest { out.write(frame1); // Read the frames ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); - ConnectionDecrypter d = new NullConnectionDecrypter(in); + ConnectionDecrypter d = new NullConnectionDecrypter(in, macLength); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); byte[] read = new byte[payloadLength]; TestUtils.readFully(r.getInputStream(), read); @@ -191,7 +191,7 @@ public class ConnectionReaderImplTest extends TransportTest { frame[12] ^= 1; // Try to read the frame - not a single byte should be read ByteArrayInputStream in = new ByteArrayInputStream(frame); - ConnectionDecrypter d = new NullConnectionDecrypter(in); + ConnectionDecrypter d = new NullConnectionDecrypter(in, macLength); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); try { r.getInputStream().read(); @@ -213,7 +213,7 @@ public class ConnectionReaderImplTest extends TransportTest { frame[17] ^= 1; // Try to read the frame - not a single byte should be read ByteArrayInputStream in = new ByteArrayInputStream(frame); - ConnectionDecrypter d = new NullConnectionDecrypter(in); + ConnectionDecrypter d = new NullConnectionDecrypter(in, macLength); ConnectionReader r = new ConnectionReaderImpl(d, mac, macKey); try { r.getInputStream().read(); diff --git a/test/net/sf/briar/transport/FrameReadWriteTest.java b/test/net/sf/briar/transport/FrameReadWriteTest.java index 7c374c9544eb2a31b5ec2a5383a5f2bd1ee8f74a..b59b3dd144db60420d31bef8b42f11b6fddbd7fb 100644 --- a/test/net/sf/briar/transport/FrameReadWriteTest.java +++ b/test/net/sf/briar/transport/FrameReadWriteTest.java @@ -90,7 +90,7 @@ public class FrameReadWriteTest extends BriarTestCase { assertTrue(TagEncoder.validateTag(tag, 0, tagCipher, tagKey)); // Read the frames back ConnectionDecrypter decrypter = new ConnectionDecrypterImpl(in, - frameCipher, frameKey); + frameCipher, frameKey, mac.getMacLength()); ConnectionReader reader = new ConnectionReaderImpl(decrypter, mac, macKey); InputStream in1 = reader.getInputStream(); diff --git a/test/net/sf/briar/transport/NullConnectionDecrypter.java b/test/net/sf/briar/transport/NullConnectionDecrypter.java index bfeb8b877ab4dfeb4c5f3e4cceff36ab2a67c1a5..68d378469e437127dcf4ed602eebdc4c83408e4e 100644 --- a/test/net/sf/briar/transport/NullConnectionDecrypter.java +++ b/test/net/sf/briar/transport/NullConnectionDecrypter.java @@ -1,29 +1,48 @@ package net.sf.briar.transport; +import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH; +import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH; + import java.io.EOFException; import java.io.IOException; import java.io.InputStream; +import net.sf.briar.api.FormatException; + /** A ConnectionDecrypter that performs no decryption. */ class NullConnectionDecrypter implements ConnectionDecrypter { private final InputStream in; + private final int macLength; - NullConnectionDecrypter(InputStream in) { + NullConnectionDecrypter(InputStream in, int macLength) { this.in = in; + this.macLength = macLength; } - public InputStream getInputStream() { - return in; - } - - public void readFinal(byte[] mac) throws IOException { - int offset = 0; - while(offset < mac.length) { - int read = in.read(mac, offset, mac.length - offset); - if(read == -1) break; + public int readFrame(byte[] b) throws IOException { + if(b.length < MAX_FRAME_LENGTH) throw new IllegalArgumentException(); + // Read the header to determine the frame length + int offset = 0, length = FRAME_HEADER_LENGTH; + while(offset < length) { + int read = in.read(b, offset, length - offset); + if(read == -1) { + if(offset == 0) return -1; + throw new EOFException(); + } + offset += read; + } + // Parse the header + int payload = HeaderEncoder.getPayloadLength(b); + int padding = HeaderEncoder.getPaddingLength(b); + length = FRAME_HEADER_LENGTH + payload + padding + macLength; + if(length > MAX_FRAME_LENGTH) throw new FormatException(); + // Read the remainder of the frame + while(offset < length) { + int read = in.read(b, offset, length - offset); + if(read == -1) throw new EOFException(); offset += read; } - if(offset < mac.length) throw new EOFException(); + return length; } } diff --git a/test/net/sf/briar/transport/NullConnectionEncrypter.java b/test/net/sf/briar/transport/NullConnectionEncrypter.java index afcbb477657c4d9ae0111e8ea80eb53da37813a5..5f59c78969ce1d2d93d12a7f5199a3a00cf07038 100644 --- a/test/net/sf/briar/transport/NullConnectionEncrypter.java +++ b/test/net/sf/briar/transport/NullConnectionEncrypter.java @@ -20,8 +20,8 @@ class NullConnectionEncrypter implements ConnectionEncrypter { this.capacity = capacity; } - public void writeFrame(byte[] b, int off, int len) throws IOException { - out.write(b, off, len); + public void writeFrame(byte[] b, int len) throws IOException { + out.write(b, 0, len); capacity -= len; }