From b2cab71637463636d0ecd946b12bcc78d0095745 Mon Sep 17 00:00:00 2001 From: akwizgran <akwizgran@users.sourceforge.net> Date: Fri, 13 Jan 2012 16:58:41 +0000 Subject: [PATCH] Decryption code for tagging every segment. --- .../briar/transport/ConnectionDecrypter.java | 46 ++++++++++------ .../transport/ConnectionEncrypterImpl.java | 2 +- .../ConnectionReaderFactoryImpl.java | 6 ++- .../SegmentedConnectionDecrypter.java | 50 +++++++++-------- .../SegmentedConnectionEncrypter.java | 2 +- .../net/sf/briar/transport/TagEncoder.java | 8 +-- .../transport/ConnectionDecrypterTest.java | 53 +++++++++++++++++-- .../briar/transport/FrameReadWriteTest.java | 4 +- .../SegmentedConnectionDecrypterTest.java | 51 +++++++++++++++--- 9 files changed, 164 insertions(+), 58 deletions(-) diff --git a/components/net/sf/briar/transport/ConnectionDecrypter.java b/components/net/sf/briar/transport/ConnectionDecrypter.java index 227c09bbdf..928f56ca74 100644 --- a/components/net/sf/briar/transport/ConnectionDecrypter.java +++ b/components/net/sf/briar/transport/ConnectionDecrypter.java @@ -2,6 +2,7 @@ package net.sf.briar.transport; import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH; import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH; +import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH; import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED; import java.io.EOFException; @@ -18,19 +19,24 @@ import net.sf.briar.api.crypto.ErasableKey; class ConnectionDecrypter implements FrameSource { private final InputStream in; - private final Cipher frameCipher; - private final ErasableKey frameKey; + private final Cipher tagCipher, frameCipher; + private final ErasableKey tagKey, frameKey; private final int macLength, blockSize; private final byte[] iv; + private final boolean tagEverySegment; private long frame = 0L; - ConnectionDecrypter(InputStream in, Cipher frameCipher, - ErasableKey frameKey, int macLength) { + ConnectionDecrypter(InputStream in, Cipher tagCipher, Cipher frameCipher, + ErasableKey tagKey, ErasableKey frameKey, int macLength, + boolean tagEverySegment) { this.in = in; + this.tagCipher = tagCipher; this.frameCipher = frameCipher; + this.tagKey = tagKey; this.frameKey = frameKey; this.macLength = macLength; + this.tagEverySegment = tagEverySegment; blockSize = frameCipher.getBlockSize(); if(blockSize < FRAME_HEADER_LENGTH) throw new IllegalArgumentException(); @@ -40,30 +46,39 @@ class ConnectionDecrypter implements FrameSource { public int readFrame(byte[] b) throws IOException { if(b.length < MAX_FRAME_LENGTH) throw new IllegalArgumentException(); if(frame > MAX_32_BIT_UNSIGNED) throw new IllegalStateException(); - // Initialise the cipher - IvEncoder.updateIv(iv, frame); - IvParameterSpec ivSpec = new IvParameterSpec(iv); - try { - frameCipher.init(Cipher.DECRYPT_MODE, frameKey, ivSpec); - } catch(GeneralSecurityException badIvOrKey) { - throw new RuntimeException(badIvOrKey); - } + boolean tag = tagEverySegment && frame > 0; // Clear the buffer before exposing it to the transport plugin for(int i = 0; i < b.length; i++) b[i] = 0; try { + // If a tag is expected then read, decrypt and validate it + if(tag) { + int offset = 0; + while(offset < TAG_LENGTH) { + int read = in.read(b, offset, TAG_LENGTH - offset); + if(read == -1) { + if(offset == 0) return -1; + throw new EOFException(); + } + offset += read; + } + if(!TagEncoder.validateTag(b, frame, tagCipher, tagKey)) + throw new FormatException(); + } // Read the first block int offset = 0; while(offset < blockSize) { int read = in.read(b, offset, blockSize - offset); if(read == -1) { - if(offset == 0) return -1; - if(offset < blockSize) throw new EOFException(); - break; + if(offset == 0 && !tag) return -1; + throw new EOFException(); } offset += read; } // Decrypt the first block try { + IvEncoder.updateIv(iv, frame); + IvParameterSpec ivSpec = new IvParameterSpec(iv); + frameCipher.init(Cipher.DECRYPT_MODE, frameKey, ivSpec); int decrypted = frameCipher.update(b, 0, blockSize, b); if(decrypted != blockSize) throw new RuntimeException(); } catch(GeneralSecurityException badCipher) { @@ -96,6 +111,7 @@ class ConnectionDecrypter implements FrameSource { return length; } catch(IOException e) { frameKey.erase(); + tagKey.erase(); throw e; } } diff --git a/components/net/sf/briar/transport/ConnectionEncrypterImpl.java b/components/net/sf/briar/transport/ConnectionEncrypterImpl.java index 3eeba89d40..08e40577ac 100644 --- a/components/net/sf/briar/transport/ConnectionEncrypterImpl.java +++ b/components/net/sf/briar/transport/ConnectionEncrypterImpl.java @@ -20,7 +20,7 @@ class ConnectionEncrypterImpl implements ConnectionEncrypter { private final boolean tagEverySegment; private final byte[] iv, tag; - private long capacity, frame = 0; + private long capacity, frame = 0L; ConnectionEncrypterImpl(OutputStream out, long capacity, Cipher tagCipher, Cipher frameCipher, ErasableKey tagKey, ErasableKey frameKey, diff --git a/components/net/sf/briar/transport/ConnectionReaderFactoryImpl.java b/components/net/sf/briar/transport/ConnectionReaderFactoryImpl.java index 36be2d4143..4149b37689 100644 --- a/components/net/sf/briar/transport/ConnectionReaderFactoryImpl.java +++ b/components/net/sf/briar/transport/ConnectionReaderFactoryImpl.java @@ -43,12 +43,14 @@ class ConnectionReaderFactoryImpl implements ConnectionReaderFactory { // Derive the keys and erase the secret ErasableKey frameKey = crypto.deriveFrameKey(secret, initiator); ErasableKey macKey = crypto.deriveMacKey(secret, initiator); + ErasableKey tagKey = crypto.deriveTagKey(secret, initiator); ByteUtils.erase(secret); // Create the decrypter + Cipher tagCipher = crypto.getTagCipher(); Cipher frameCipher = crypto.getFrameCipher(); Mac mac = crypto.getMac(); - FrameSource decrypter = new ConnectionDecrypter(in, - frameCipher, frameKey, mac.getMacLength()); + FrameSource decrypter = new ConnectionDecrypter(in, tagCipher, + frameCipher, tagKey, frameKey, mac.getMacLength(), false); // Create the reader return new ConnectionReaderImpl(decrypter, mac, macKey); } diff --git a/components/net/sf/briar/transport/SegmentedConnectionDecrypter.java b/components/net/sf/briar/transport/SegmentedConnectionDecrypter.java index 47bc7b1423..59af5f185c 100644 --- a/components/net/sf/briar/transport/SegmentedConnectionDecrypter.java +++ b/components/net/sf/briar/transport/SegmentedConnectionDecrypter.java @@ -2,6 +2,7 @@ package net.sf.briar.transport; import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH; import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH; +import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH; import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED; import java.io.IOException; @@ -18,20 +19,25 @@ import net.sf.briar.api.plugins.SegmentSource; class SegmentedConnectionDecrypter implements FrameSource { private final SegmentSource in; - private final Cipher frameCipher; - private final ErasableKey frameKey; + private final Cipher tagCipher, frameCipher; + private final ErasableKey tagKey, frameKey; private final int macLength, blockSize; private final byte[] iv; private final Segment segment; + private final boolean tagEverySegment; private long frame = 0L; - SegmentedConnectionDecrypter(SegmentSource in, Cipher frameCipher, - ErasableKey frameKey, int macLength) { + SegmentedConnectionDecrypter(SegmentSource in, Cipher tagCipher, + Cipher frameCipher, ErasableKey tagKey, ErasableKey frameKey, + int macLength, boolean tagEverySegment) { this.in = in; + this.tagCipher = tagCipher; this.frameCipher = frameCipher; + this.tagKey = tagKey; this.frameKey = frameKey; this.macLength = macLength; + this.tagEverySegment = tagEverySegment; blockSize = frameCipher.getBlockSize(); if(blockSize < FRAME_HEADER_LENGTH) throw new IllegalArgumentException(); @@ -42,30 +48,27 @@ class SegmentedConnectionDecrypter implements FrameSource { public int readFrame(byte[] b) throws IOException { if(b.length < MAX_FRAME_LENGTH) throw new IllegalArgumentException(); if(frame > MAX_32_BIT_UNSIGNED) throw new IllegalStateException(); - // Initialise the cipher - IvEncoder.updateIv(iv, frame); - IvParameterSpec ivSpec = new IvParameterSpec(iv); - try { - frameCipher.init(Cipher.DECRYPT_MODE, frameKey, ivSpec); - } catch(GeneralSecurityException badIvOrKey) { - throw new RuntimeException(badIvOrKey); - } + boolean tag = tagEverySegment && frame > 0; // Clear the buffer before exposing it to the transport plugin segment.clear(); try { - // Read the frame + // Read the segment if(!in.readSegment(segment)) return -1; - if(segment.getTransmissionNumber() != frame) - throw new FormatException(); - int length = segment.getLength(); + int offset = tag ? TAG_LENGTH : 0, length = segment.getLength(); if(length > MAX_FRAME_LENGTH) throw new FormatException(); - if(length < FRAME_HEADER_LENGTH + macLength) + if(length < offset + FRAME_HEADER_LENGTH + macLength) throw new FormatException(); + // If a tag is expected, decrypt and validate it + if(tag && !TagEncoder.validateTag(segment.getBuffer(), frame, + tagCipher, tagKey)) throw new FormatException(); // Decrypt the frame try { - int decrypted = frameCipher.doFinal(segment.getBuffer(), 0, - length, b); - if(decrypted != length) throw new RuntimeException(); + IvEncoder.updateIv(iv, frame); + IvParameterSpec ivSpec = new IvParameterSpec(iv); + frameCipher.init(Cipher.DECRYPT_MODE, frameKey, ivSpec); + int decrypted = frameCipher.doFinal(segment.getBuffer(), offset, + length - offset, b); + if(decrypted != length - offset) throw new RuntimeException(); } catch(GeneralSecurityException badCipher) { throw new RuntimeException(badCipher); } @@ -75,12 +78,13 @@ class SegmentedConnectionDecrypter implements FrameSource { throw new FormatException(); int payload = HeaderEncoder.getPayloadLength(b); int padding = HeaderEncoder.getPaddingLength(b); - if(length != FRAME_HEADER_LENGTH + payload + padding + macLength) - throw new FormatException(); + if(length != offset + FRAME_HEADER_LENGTH + payload + padding + + macLength) throw new FormatException(); frame++; - return length; + return length - offset; } catch(IOException e) { frameKey.erase(); + tagKey.erase(); throw e; } } diff --git a/components/net/sf/briar/transport/SegmentedConnectionEncrypter.java b/components/net/sf/briar/transport/SegmentedConnectionEncrypter.java index f757425aac..29af654f2a 100644 --- a/components/net/sf/briar/transport/SegmentedConnectionEncrypter.java +++ b/components/net/sf/briar/transport/SegmentedConnectionEncrypter.java @@ -23,7 +23,7 @@ class SegmentedConnectionEncrypter implements ConnectionEncrypter { private final byte[] iv; private final Segment segment; - private long capacity, frame = 0; + private long capacity, frame = 0L; SegmentedConnectionEncrypter(SegmentSink out, long capacity, Cipher tagCipher, Cipher frameCipher, ErasableKey tagKey, diff --git a/components/net/sf/briar/transport/TagEncoder.java b/components/net/sf/briar/transport/TagEncoder.java index df83c7d8fd..6afa2bead7 100644 --- a/components/net/sf/briar/transport/TagEncoder.java +++ b/components/net/sf/briar/transport/TagEncoder.java @@ -39,15 +39,17 @@ class TagEncoder { ErasableKey tagKey) { if(frame < 0 || frame > MAX_32_BIT_UNSIGNED) throw new IllegalArgumentException(); - if(tag.length != TAG_LENGTH) return false; + if(tag.length < TAG_LENGTH) return false; // Encode the frame number as a uint32 at the end of the IV byte[] iv = new byte[tagCipher.getBlockSize()]; - if(iv.length != tag.length) throw new IllegalArgumentException(); + if(iv.length != TAG_LENGTH) throw new IllegalArgumentException(); ByteUtils.writeUint32(frame, iv, iv.length - 4); IvParameterSpec ivSpec = new IvParameterSpec(iv); try { tagCipher.init(Cipher.DECRYPT_MODE, tagKey, ivSpec); - byte[] plaintext = tagCipher.doFinal(tag); + byte[] plaintext = tagCipher.doFinal(tag, 0, TAG_LENGTH); + if(plaintext.length != TAG_LENGTH) + throw new IllegalArgumentException(); // The plaintext should be blank for(int i = 0; i < plaintext.length; i++) { if(plaintext[i] != 0) return false; diff --git a/test/net/sf/briar/transport/ConnectionDecrypterTest.java b/test/net/sf/briar/transport/ConnectionDecrypterTest.java index 829c73335a..b79bff4d42 100644 --- a/test/net/sf/briar/transport/ConnectionDecrypterTest.java +++ b/test/net/sf/briar/transport/ConnectionDecrypterTest.java @@ -2,6 +2,7 @@ package net.sf.briar.transport; import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH; import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH; +import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH; import java.io.ByteArrayInputStream; @@ -23,19 +24,21 @@ public class ConnectionDecrypterTest extends BriarTestCase { private static final int MAC_LENGTH = 32; - private final Cipher frameCipher; - private final ErasableKey frameKey; + private final Cipher tagCipher, frameCipher; + private final ErasableKey tagKey, frameKey; public ConnectionDecrypterTest() { super(); Injector i = Guice.createInjector(new CryptoModule()); CryptoComponent crypto = i.getInstance(CryptoComponent.class); + tagCipher = crypto.getTagCipher(); frameCipher = crypto.getFrameCipher(); + tagKey = crypto.generateTestKey(); frameKey = crypto.generateTestKey(); } @Test - public void testDecryption() throws Exception { + public void testDecryptionWithFirstSegmentTagged() throws Exception { // Calculate the ciphertext for the first frame byte[] plaintext = new byte[FRAME_HEADER_LENGTH + 123 + MAC_LENGTH]; HeaderEncoder.encodeHeader(plaintext, 0L, 123, 0); @@ -57,8 +60,48 @@ public class ConnectionDecrypterTest extends BriarTestCase { out.write(ciphertext1); ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); // Use a ConnectionDecrypter to decrypt the ciphertext - FrameSource decrypter = new ConnectionDecrypter(in, frameCipher, - frameKey, MAC_LENGTH); + FrameSource decrypter = new ConnectionDecrypter(in, tagCipher, + frameCipher, tagKey, frameKey, MAC_LENGTH, false); + // First frame + byte[] decrypted = new byte[MAX_FRAME_LENGTH]; + assertEquals(plaintext.length, decrypter.readFrame(decrypted)); + for(int i = 0; i < plaintext.length; i++) { + assertEquals(plaintext[i], decrypted[i]); + } + // Second frame + assertEquals(plaintext1.length, decrypter.readFrame(decrypted)); + for(int i = 0; i < plaintext1.length; i++) { + assertEquals(plaintext1[i], decrypted[i]); + } + } + + @Test + public void testDecryptionWithEverySegmentTagged() throws Exception { + // Calculate the ciphertext for the first frame + byte[] plaintext = new byte[FRAME_HEADER_LENGTH + 123 + MAC_LENGTH]; + HeaderEncoder.encodeHeader(plaintext, 0L, 123, 0); + byte[] iv = IvEncoder.encodeIv(0L, frameCipher.getBlockSize()); + IvParameterSpec ivSpec = new IvParameterSpec(iv); + frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec); + byte[] ciphertext = frameCipher.doFinal(plaintext, 0, plaintext.length); + // Calculate the ciphertext for the second frame, including its tag + byte[] plaintext1 = new byte[FRAME_HEADER_LENGTH + 1234 + MAC_LENGTH]; + HeaderEncoder.encodeHeader(plaintext1, 1L, 1234, 0); + byte[] ciphertext1 = new byte[TAG_LENGTH + plaintext1.length]; + TagEncoder.encodeTag(ciphertext1, 1, tagCipher, tagKey); + IvEncoder.updateIv(iv, 1L); + ivSpec = new IvParameterSpec(iv); + frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec); + frameCipher.doFinal(plaintext1, 0, plaintext1.length, ciphertext1, + TAG_LENGTH); + // Concatenate the ciphertexts + ByteArrayOutputStream out = new ByteArrayOutputStream(); + out.write(ciphertext); + out.write(ciphertext1); + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + // Use a ConnectionDecrypter to decrypt the ciphertext + FrameSource decrypter = new ConnectionDecrypter(in, tagCipher, + frameCipher, tagKey, frameKey, MAC_LENGTH, true); // First frame byte[] decrypted = new byte[MAX_FRAME_LENGTH]; assertEquals(plaintext.length, decrypter.readFrame(decrypted)); diff --git a/test/net/sf/briar/transport/FrameReadWriteTest.java b/test/net/sf/briar/transport/FrameReadWriteTest.java index bfa229e6fe..051f9f66bc 100644 --- a/test/net/sf/briar/transport/FrameReadWriteTest.java +++ b/test/net/sf/briar/transport/FrameReadWriteTest.java @@ -91,8 +91,8 @@ public class FrameReadWriteTest extends BriarTestCase { assertArrayEquals(tag, recoveredTag); assertTrue(TagEncoder.validateTag(tag, 0, tagCipher, tagKey)); // Read the frames back - FrameSource decrypter = new ConnectionDecrypter(in, frameCipher, - frameKey, mac.getMacLength()); + FrameSource decrypter = new ConnectionDecrypter(in, tagCipher, + frameCipher, tagKey, frameKey, mac.getMacLength(), false); ConnectionReader reader = new ConnectionReaderImpl(decrypter, mac, macKey); InputStream in1 = reader.getInputStream(); diff --git a/test/net/sf/briar/transport/SegmentedConnectionDecrypterTest.java b/test/net/sf/briar/transport/SegmentedConnectionDecrypterTest.java index 20727365e8..539b6dd0b2 100644 --- a/test/net/sf/briar/transport/SegmentedConnectionDecrypterTest.java +++ b/test/net/sf/briar/transport/SegmentedConnectionDecrypterTest.java @@ -2,6 +2,7 @@ package net.sf.briar.transport; import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH; import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH; +import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH; import java.io.IOException; @@ -24,19 +25,21 @@ public class SegmentedConnectionDecrypterTest extends BriarTestCase { private static final int MAC_LENGTH = 32; - private final Cipher frameCipher; - private final ErasableKey frameKey; + private final Cipher tagCipher, frameCipher; + private final ErasableKey tagKey, frameKey; public SegmentedConnectionDecrypterTest() { super(); Injector i = Guice.createInjector(new CryptoModule()); CryptoComponent crypto = i.getInstance(CryptoComponent.class); + tagCipher = crypto.getTagCipher(); frameCipher = crypto.getFrameCipher(); + tagKey = crypto.generateTestKey(); frameKey = crypto.generateTestKey(); } @Test - public void testDecryption() throws Exception { + public void testDecryptionWithFirstSegmentTagged() throws Exception { // Calculate the ciphertext for the first frame byte[] plaintext = new byte[FRAME_HEADER_LENGTH + 123 + MAC_LENGTH]; HeaderEncoder.encodeHeader(plaintext, 0L, 123, 0); @@ -55,8 +58,45 @@ public class SegmentedConnectionDecrypterTest extends BriarTestCase { // Use a connection decrypter to decrypt the ciphertext byte[][] frames = new byte[][] { ciphertext, ciphertext1 }; SegmentSource in = new ByteArraySegmentSource(frames); - FrameSource decrypter = new SegmentedConnectionDecrypter(in, - frameCipher, frameKey, MAC_LENGTH); + FrameSource decrypter = new SegmentedConnectionDecrypter(in, tagCipher, + frameCipher, tagKey, frameKey, MAC_LENGTH, false); + // First frame + byte[] decrypted = new byte[MAX_FRAME_LENGTH]; + assertEquals(plaintext.length, decrypter.readFrame(decrypted)); + for(int i = 0; i < plaintext.length; i++) { + assertEquals(plaintext[i], decrypted[i]); + } + // Second frame + assertEquals(plaintext1.length, decrypter.readFrame(decrypted)); + for(int i = 0; i < plaintext1.length; i++) { + assertEquals(plaintext1[i], decrypted[i]); + } + } + + @Test + public void testDecryptionWithEverySegmentTagged() throws Exception { + // Calculate the ciphertext for the first frame + byte[] plaintext = new byte[FRAME_HEADER_LENGTH + 123 + MAC_LENGTH]; + HeaderEncoder.encodeHeader(plaintext, 0L, 123, 0); + byte[] iv = IvEncoder.encodeIv(0L, frameCipher.getBlockSize()); + IvParameterSpec ivSpec = new IvParameterSpec(iv); + frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec); + byte[] ciphertext = frameCipher.doFinal(plaintext, 0, plaintext.length); + // Calculate the ciphertext for the second frame, including its tag + byte[] plaintext1 = new byte[FRAME_HEADER_LENGTH + 1234 + MAC_LENGTH]; + HeaderEncoder.encodeHeader(plaintext1, 1L, 1234, 0); + byte[] ciphertext1 = new byte[TAG_LENGTH + plaintext1.length]; + TagEncoder.encodeTag(ciphertext1, 1, tagCipher, tagKey); + IvEncoder.updateIv(iv, 1L); + ivSpec = new IvParameterSpec(iv); + frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec); + frameCipher.doFinal(plaintext1, 0, plaintext1.length, ciphertext1, + TAG_LENGTH); + // Use a connection decrypter to decrypt the ciphertext + byte[][] frames = new byte[][] { ciphertext, ciphertext1 }; + SegmentSource in = new ByteArraySegmentSource(frames); + FrameSource decrypter = new SegmentedConnectionDecrypter(in, tagCipher, + frameCipher, tagKey, frameKey, MAC_LENGTH, true); // First frame byte[] decrypted = new byte[MAX_FRAME_LENGTH]; assertEquals(plaintext.length, decrypter.readFrame(decrypted)); @@ -85,7 +125,6 @@ public class SegmentedConnectionDecrypterTest extends BriarTestCase { byte[] src = frames[frame]; System.arraycopy(src, 0, s.getBuffer(), 0, src.length); s.setLength(src.length); - s.setTransmissionNumber(frame); frame++; return true; } -- GitLab