diff --git a/components/net/sf/briar/transport/OutgoingEncryptionLayer.java b/components/net/sf/briar/transport/OutgoingEncryptionLayer.java index 9d6a67260cf6a0737c77598b3db88b802f95bec2..22bb1bcea3b5255a94391f828207199722ff1595 100644 --- a/components/net/sf/briar/transport/OutgoingEncryptionLayer.java +++ b/components/net/sf/briar/transport/OutgoingEncryptionLayer.java @@ -24,7 +24,7 @@ class OutgoingEncryptionLayer implements FrameWriter { private final AuthenticatedCipher frameCipher; private final ErasableKey tagKey, frameKey; private final byte[] iv, aad, ciphertext; - private final int frameLength; + private final int frameLength, maxPayloadLength; private long capacity, frameNumber; private boolean writeTag; @@ -40,6 +40,7 @@ class OutgoingEncryptionLayer implements FrameWriter { this.tagKey = tagKey; this.frameKey = frameKey; this.frameLength = frameLength; + maxPayloadLength = frameLength - HEADER_LENGTH - MAC_LENGTH; iv = new byte[IV_LENGTH]; aad = new byte[AAD_LENGTH]; ciphertext = new byte[frameLength]; @@ -56,6 +57,7 @@ class OutgoingEncryptionLayer implements FrameWriter { this.frameCipher = frameCipher; this.frameKey = frameKey; this.frameLength = frameLength; + maxPayloadLength = frameLength - HEADER_LENGTH - MAC_LENGTH; tagCipher = null; tagKey = null; iv = new byte[IV_LENGTH]; @@ -125,10 +127,26 @@ class OutgoingEncryptionLayer implements FrameWriter { } public long getRemainingCapacity() { - long capacityExcludingTag = writeTag ? capacity - TAG_LENGTH : capacity; - long frames = capacityExcludingTag / frameLength; + // How many frame numbers can we use? long frameNumbers = MAX_32_BIT_UNSIGNED - frameNumber + 1; - int maxPayloadLength = frameLength - HEADER_LENGTH - MAC_LENGTH; - return maxPayloadLength * Math.min(frames, frameNumbers); + // How many full frames do we have space for? + long bytes = writeTag ? capacity - TAG_LENGTH : capacity; + long fullFrames = bytes / frameLength; + // Are we limited by frame numbers or space? + if(frameNumbers > fullFrames) { + // Can we send a partial frame after the full frames? + int partialFrame = (int) (bytes - fullFrames * frameLength); + if(partialFrame > HEADER_LENGTH + MAC_LENGTH) { + // Send full frames and a partial frame, limited by space + int partialPayload = partialFrame - HEADER_LENGTH - MAC_LENGTH; + return maxPayloadLength * fullFrames + partialPayload; + } else { + // Send full frames only, limited by space + return maxPayloadLength * fullFrames; + } + } else { + // Send full frames only, limited by frame numbers + return maxPayloadLength * frameNumbers; + } } } \ No newline at end of file diff --git a/test/net/sf/briar/transport/OutgoingEncryptionLayerTest.java b/test/net/sf/briar/transport/OutgoingEncryptionLayerTest.java index ae1516e2ca8f8002f7738ef7c0362de310c263e8..3d22d802fbf2762379146f17ac99dda09bc7ce73 100644 --- a/test/net/sf/briar/transport/OutgoingEncryptionLayerTest.java +++ b/test/net/sf/briar/transport/OutgoingEncryptionLayerTest.java @@ -1,12 +1,103 @@ package net.sf.briar.transport; -import org.junit.Test; +import static net.sf.briar.api.transport.TransportConstants.HEADER_LENGTH; +import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH; +import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH; + +import java.io.ByteArrayOutputStream; + +import javax.crypto.Cipher; import net.sf.briar.BriarTestCase; +import net.sf.briar.api.crypto.AuthenticatedCipher; +import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.crypto.CryptoModule; + +import org.junit.Test; + +import com.google.inject.Guice; +import com.google.inject.Injector; public class OutgoingEncryptionLayerTest extends BriarTestCase { - // FIXME: Write tests + // FIXME: Write more tests + + private final CryptoComponent crypto; + private final Cipher tagCipher; + private final AuthenticatedCipher frameCipher; + + public OutgoingEncryptionLayerTest() { + super(); + Injector i = Guice.createInjector(new CryptoModule()); + crypto = i.getInstance(CryptoComponent.class); + tagCipher = crypto.getTagCipher(); + frameCipher = crypto.getFrameCipher(); + } + + @Test + public void testRemainingCapacityWithTag() throws Exception { + int frameLength = 1024; + int maxPayloadLength = frameLength - HEADER_LENGTH - MAC_LENGTH; + long capacity = 10 * frameLength; + ByteArrayOutputStream out = new ByteArrayOutputStream(); + OutgoingEncryptionLayer o = new OutgoingEncryptionLayer(out, capacity, + tagCipher, frameCipher, crypto.generateTestKey(), + crypto.generateTestKey(), frameLength); + // There should be space for nine full frames and one partial frame + byte[] frame = new byte[frameLength]; + assertEquals(10 * maxPayloadLength - TAG_LENGTH, + o.getRemainingCapacity()); + // Write nine frames, each containing a partial payload + for(int i = 9; i > 0; i--) { + o.writeFrame(frame, 123, false); + assertEquals(i * maxPayloadLength - TAG_LENGTH, + o.getRemainingCapacity()); + } + // Write the final frame, which will not be padded + o.writeFrame(frame, 123, true); + int finalFrameLength = HEADER_LENGTH + 123 + MAC_LENGTH; + assertEquals(maxPayloadLength - TAG_LENGTH - finalFrameLength, + o.getRemainingCapacity()); + } + + @Test + public void testRemainingCapacityWithoutTag() throws Exception { + int frameLength = 1024; + int maxPayloadLength = frameLength - HEADER_LENGTH - MAC_LENGTH; + long capacity = 10 * frameLength; + ByteArrayOutputStream out = new ByteArrayOutputStream(); + OutgoingEncryptionLayer o = new OutgoingEncryptionLayer(out, capacity, + frameCipher, crypto.generateTestKey(), frameLength); + // There should be space for ten full frames + assertEquals(10 * maxPayloadLength, o.getRemainingCapacity()); + // Write nine frames, each containing a partial payload + byte[] frame = new byte[frameLength]; + for(int i = 9; i > 0; i--) { + o.writeFrame(frame, 123, false); + assertEquals(i * maxPayloadLength, o.getRemainingCapacity()); + } + // Write the final frame, which will not be padded + o.writeFrame(frame, 123, true); + int finalFrameLength = HEADER_LENGTH + 123 + MAC_LENGTH; + assertEquals(maxPayloadLength - finalFrameLength, + o.getRemainingCapacity()); + } + @Test - public void testNothing() {} + public void testRemainingCapacityLimitedByFrameNumbers() throws Exception { + int frameLength = 1024; + int maxPayloadLength = frameLength - HEADER_LENGTH - MAC_LENGTH; + long capacity = Long.MAX_VALUE; + ByteArrayOutputStream out = new ByteArrayOutputStream(); + OutgoingEncryptionLayer o = new OutgoingEncryptionLayer(out, capacity, + frameCipher, crypto.generateTestKey(), frameLength); + // There should be enough frame numbers for 2^32 frames + assertEquals((1L << 32) * maxPayloadLength, o.getRemainingCapacity()); + // Write a frame containing a partial payload + byte[] frame = new byte[frameLength]; + o.writeFrame(frame, 123, false); + // There should be enough frame numbers for 2^32 - 1 frames + assertEquals(((1L << 32) - 1) * maxPayloadLength, + o.getRemainingCapacity()); + } }