From 2c387f80b13c0635404809bbd9af8eb4bfb5d2b3 Mon Sep 17 00:00:00 2001 From: akwizgran <akwizgran@users.sourceforge.net> Date: Fri, 12 Aug 2011 17:14:58 +0200 Subject: [PATCH] Packet decrypter with unit tests. Decryption is complicated by the fact that the cipher wants to operate a block at a time even though it's in CTR mode. --- .../sf/briar/transport/PacketDecrypter.java | 13 ++ .../briar/transport/PacketDecrypterImpl.java | 162 ++++++++++++++++++ .../briar/transport/PacketEncrypterImpl.java | 3 + test/build.xml | 1 + .../transport/PacketDecrypterImplTest.java | 113 ++++++++++++ .../transport/PacketEncrypterImplTest.java | 2 +- 6 files changed, 293 insertions(+), 1 deletion(-) create mode 100644 components/net/sf/briar/transport/PacketDecrypter.java create mode 100644 components/net/sf/briar/transport/PacketDecrypterImpl.java create mode 100644 test/net/sf/briar/transport/PacketDecrypterImplTest.java diff --git a/components/net/sf/briar/transport/PacketDecrypter.java b/components/net/sf/briar/transport/PacketDecrypter.java new file mode 100644 index 0000000000..c8587d3ae0 --- /dev/null +++ b/components/net/sf/briar/transport/PacketDecrypter.java @@ -0,0 +1,13 @@ +package net.sf.briar.transport; + +import java.io.IOException; +import java.io.InputStream; + +interface PacketDecrypter { + + /** Returns the input stream from which packets should be read. */ + InputStream getInputStream(); + + /** Reads, decrypts and returns a tag from the underlying input stream. */ + byte[] readTag() throws IOException; +} diff --git a/components/net/sf/briar/transport/PacketDecrypterImpl.java b/components/net/sf/briar/transport/PacketDecrypterImpl.java new file mode 100644 index 0000000000..c0645af462 --- /dev/null +++ b/components/net/sf/briar/transport/PacketDecrypterImpl.java @@ -0,0 +1,162 @@ +package net.sf.briar.transport; + +import java.io.EOFException; +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.security.InvalidAlgorithmParameterException; +import java.security.InvalidKeyException; +import java.util.Arrays; + +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.SecretKey; +import javax.crypto.ShortBufferException; +import javax.crypto.spec.IvParameterSpec; + +class PacketDecrypterImpl extends FilterInputStream implements PacketDecrypter { + + private final Cipher tagCipher, packetCipher; + private final SecretKey packetKey; + + private byte[] cipherBuf, plainBuf; + private int bufOff = 0, bufLen = Constants.TAG_BYTES; + private boolean betweenPackets = true; + + PacketDecrypterImpl(byte[] firstTag, InputStream in, Cipher tagCipher, + Cipher packetCipher, SecretKey tagKey, SecretKey packetKey) { + super(in); + if(firstTag.length != Constants.TAG_BYTES) + throw new IllegalArgumentException(); + cipherBuf = Arrays.copyOf(firstTag, firstTag.length); + plainBuf = new byte[Constants.TAG_BYTES]; + this.tagCipher = tagCipher; + this.packetCipher = packetCipher; + this.packetKey = packetKey; + try { + tagCipher.init(Cipher.DECRYPT_MODE, tagKey); + } catch(InvalidKeyException e) { + throw new IllegalArgumentException(e); + } + if(tagCipher.getOutputSize(Constants.TAG_BYTES) != Constants.TAG_BYTES) + throw new IllegalArgumentException(); + } + + public InputStream getInputStream() { + return this; + } + + public byte[] readTag() throws IOException { + byte[] tag = new byte[Constants.TAG_BYTES]; + System.arraycopy(cipherBuf, bufOff, tag, 0, bufLen); + int offset = bufLen; + bufOff = bufLen = 0; + while(offset < tag.length) { + int read = in.read(tag, offset, tag.length - offset); + if(read == -1) throw new EOFException(); + offset += read; + } + betweenPackets = false; + try { + byte[] decryptedTag = tagCipher.doFinal(tag); + IvParameterSpec iv = new IvParameterSpec(decryptedTag); + packetCipher.init(Cipher.DECRYPT_MODE, packetKey, iv); + return decryptedTag; + } catch(BadPaddingException badCipher) { + throw new RuntimeException(badCipher); + } catch(IllegalBlockSizeException badCipher) { + throw new RuntimeException(badCipher); + } catch(InvalidAlgorithmParameterException badIv) { + throw new RuntimeException(badIv); + } catch(InvalidKeyException badKey) { + throw new RuntimeException(badKey); + } + } + + @Override + public int read() throws IOException { + if(betweenPackets) throw new IllegalStateException(); + if(bufLen == 0) { + int read = readBlock(); + if(read == 0) return -1; + bufOff = 0; + bufLen = read; + } + int i = plainBuf[bufOff]; + bufOff++; + bufLen--; + return i < 0 ? i + 256 : i; + } + + // Although we're using CTR mode, which doesn't require full blocks of + // ciphertext, the cipher still tries to operate a block at a time. We must + // either call update() with a full block or doFinal() with the last + // (possibly partial) block. + private int readBlock() throws IOException { + // Try to read a block of ciphertext + int off = 0; + while(off < cipherBuf.length) { + int read = in.read(cipherBuf, off, cipherBuf.length - off); + if(read == -1) break; + off += read; + } + if(off == 0) return 0; + // Did we get a whole block? If not we must be at EOF + if(off < cipherBuf.length) { + // We're at EOF so we can call doFinal() to force decryption + try { + int i = packetCipher.doFinal(cipherBuf, 0, off, plainBuf); + if(i < off) throw new RuntimeException(); + betweenPackets = true; + } catch(BadPaddingException badCipher) { + throw new RuntimeException(badCipher); + } catch(IllegalBlockSizeException badCipher) { + throw new RuntimeException(badCipher); + } catch(ShortBufferException badCipher) { + throw new RuntimeException(badCipher); + } + } else { + // We're not at EOF but we have a whole block to decrypt + try { + int i = packetCipher.update(cipherBuf, 0, off, plainBuf); + if(i < off) throw new RuntimeException(); + } catch(ShortBufferException badCipher) { + throw new RuntimeException(badCipher); + } + } + return off; + } + + @Override + public int read(byte[] b) throws IOException { + if(betweenPackets) throw new IllegalStateException(); + if(bufLen == 0) { + int read = readBlock(); + if(read == 0) return -1; + bufOff = 0; + bufLen = read; + } + int length = Math.min(b.length, bufLen); + System.arraycopy(plainBuf, bufOff, b, 0, length); + bufOff += length; + bufLen -= length; + return length; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + if(betweenPackets) throw new IllegalStateException(); + if(bufLen == 0) { + int read = readBlock(); + if(read == 0) return -1; + bufOff = 0; + bufLen = read; + } + int length = Math.min(len, bufLen); + System.arraycopy(plainBuf, bufOff, b, off, length); + bufOff += length; + bufLen -= length; + return length; + } +} diff --git a/components/net/sf/briar/transport/PacketEncrypterImpl.java b/components/net/sf/briar/transport/PacketEncrypterImpl.java index 69ca0c9ca3..5629bcae44 100644 --- a/components/net/sf/briar/transport/PacketEncrypterImpl.java +++ b/components/net/sf/briar/transport/PacketEncrypterImpl.java @@ -67,18 +67,21 @@ implements PacketEncrypter { @Override public void write(int b) throws IOException { + // FIXME: Encrypt into same buffer byte[] ciphertext = packetCipher.update(new byte[] {(byte) b}); if(ciphertext != null) out.write(ciphertext); } @Override public void write(byte[] b) throws IOException { + // FIXME: Encrypt into same buffer byte[] ciphertext = packetCipher.update(b); if(ciphertext != null) out.write(ciphertext); } @Override public void write(byte[] b, int off, int len) throws IOException { + // FIXME: Encrypt into same buffer byte[] ciphertext = packetCipher.update(b, off, len); if(ciphertext != null) out.write(ciphertext); } diff --git a/test/build.xml b/test/build.xml index d008290f42..90ccb67ce2 100644 --- a/test/build.xml +++ b/test/build.xml @@ -34,6 +34,7 @@ <test name='net.sf.briar.setup.SetupWorkerTest'/> <test name='net.sf.briar.transport.ConnectionRecogniserImplTest'/> <test name='net.sf.briar.transport.ConnectionWindowImplTest'/> + <test name='net.sf.briar.transport.PacketDecrypterImplTest'/> <test name='net.sf.briar.transport.PacketEncrypterImplTest'/> <test name='net.sf.briar.transport.PacketWriterImplTest'/> <test name='net.sf.briar.transport.TagEncoderTest'/> diff --git a/test/net/sf/briar/transport/PacketDecrypterImplTest.java b/test/net/sf/briar/transport/PacketDecrypterImplTest.java new file mode 100644 index 0000000000..e4ec4f2d3c --- /dev/null +++ b/test/net/sf/briar/transport/PacketDecrypterImplTest.java @@ -0,0 +1,113 @@ +package net.sf.briar.transport; + +import java.io.ByteArrayInputStream; +import java.util.Arrays; + +import javax.crypto.Cipher; +import javax.crypto.SecretKey; +import javax.crypto.spec.IvParameterSpec; + +import junit.framework.TestCase; +import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.crypto.CryptoModule; + +import org.junit.Test; + +import com.google.inject.Guice; +import com.google.inject.Injector; + +public class PacketDecrypterImplTest extends TestCase { + + private final Cipher tagCipher, packetCipher; + private final SecretKey tagKey, packetKey; + + public PacketDecrypterImplTest() { + super(); + Injector i = Guice.createInjector(new CryptoModule()); + CryptoComponent crypto = i.getInstance(CryptoComponent.class); + tagCipher = crypto.getTagCipher(); + packetCipher = crypto.getPacketCipher(); + tagKey = crypto.generateSecretKey(); + packetKey = crypto.generateSecretKey(); + } + + @Test + public void testSingleBytePackets() throws Exception { + byte[] ciphertext = new byte[(Constants.TAG_BYTES + 1) * 2]; + ByteArrayInputStream in = new ByteArrayInputStream(ciphertext); + byte[] firstTag = new byte[Constants.TAG_BYTES]; + assertEquals(Constants.TAG_BYTES, in.read(firstTag)); + PacketDecrypter p = new PacketDecrypterImpl(firstTag, in, tagCipher, + packetCipher, tagKey, packetKey); + byte[] decryptedTag = p.readTag(); + assertEquals(Constants.TAG_BYTES, decryptedTag.length); + assertTrue(p.getInputStream().read() > -1); + byte[] decryptedTag1 = p.readTag(); + assertEquals(Constants.TAG_BYTES, decryptedTag1.length); + assertTrue(p.getInputStream().read() > -1); + } + + @Test + public void testDecryption() throws Exception { + byte[] tag = new byte[Constants.TAG_BYTES]; + byte[] packet = new byte[123]; + byte[] tag1 = new byte[Constants.TAG_BYTES]; + byte[] packet1 = new byte[234]; + // Calculate the first expected decrypted tag + tagCipher.init(Cipher.DECRYPT_MODE, tagKey); + byte[] expectedTag = tagCipher.doFinal(tag); + assertEquals(tag.length, expectedTag.length); + // Calculate the first expected decrypted packet + IvParameterSpec iv = new IvParameterSpec(expectedTag); + packetCipher.init(Cipher.DECRYPT_MODE, packetKey, iv); + byte[] expectedPacket = packetCipher.doFinal(packet); + assertEquals(packet.length, expectedPacket.length); + // Calculate the second expected decrypted tag + tagCipher.init(Cipher.DECRYPT_MODE, tagKey); + byte[] expectedTag1 = tagCipher.doFinal(tag1); + assertEquals(tag1.length, expectedTag1.length); + // Calculate the second expected decrypted packet + IvParameterSpec iv1 = new IvParameterSpec(expectedTag1); + packetCipher.init(Cipher.DECRYPT_MODE, packetKey, iv1); + byte[] expectedPacket1 = packetCipher.doFinal(packet1); + assertEquals(packet1.length, expectedPacket1.length); + // Check that the PacketDecrypter gets the same results + byte[] ciphertext = new byte[tag.length + packet.length + + tag1.length + packet1.length]; + System.arraycopy(tag, 0, ciphertext, 0, tag.length); + System.arraycopy(packet, 0, ciphertext, tag.length, packet.length); + System.arraycopy(tag1, 0, ciphertext, tag.length + packet.length, + tag1.length); + System.arraycopy(packet1, 0, ciphertext, + tag.length + packet.length + tag1.length, packet1.length); + ByteArrayInputStream in = new ByteArrayInputStream(ciphertext); + PacketDecrypter p = new PacketDecrypterImpl(tag, in, tagCipher, + packetCipher, tagKey, packetKey); + // First tag + assertTrue(Arrays.equals(expectedTag, p.readTag())); + // First packet + byte[] actualPacket = new byte[packet.length]; + int offset = 0; + while(offset < actualPacket.length) { + int read = p.getInputStream().read(actualPacket, offset, + actualPacket.length - offset); + if(read == -1) break; + offset += read; + } + assertEquals(actualPacket.length, offset); + assertTrue(Arrays.equals(expectedPacket, actualPacket)); + // Second tag + assertTrue(Arrays.equals(expectedTag1, p.readTag())); + // Second packet + byte[] actualPacket1 = new byte[packet1.length]; + offset = 0; + while(offset < actualPacket1.length) { + int read = p.getInputStream().read(actualPacket1, offset, + actualPacket1.length - offset); + if(read == -1) break; + offset += read; + } + assertEquals(actualPacket1.length, offset); + assertTrue(Arrays.equals(expectedPacket1, actualPacket1)); + } +} diff --git a/test/net/sf/briar/transport/PacketEncrypterImplTest.java b/test/net/sf/briar/transport/PacketEncrypterImplTest.java index 2132223f2c..2b06dfae38 100644 --- a/test/net/sf/briar/transport/PacketEncrypterImplTest.java +++ b/test/net/sf/briar/transport/PacketEncrypterImplTest.java @@ -39,7 +39,7 @@ public class PacketEncrypterImplTest extends TestCase { p.writeTag(new byte[Constants.TAG_BYTES]); p.getOutputStream().write((byte) 0); p.finishPacket(); - assertEquals(17, out.toByteArray().length); + assertEquals(Constants.TAG_BYTES + 1, out.toByteArray().length); } @Test -- GitLab