From f55f98f506dc9096c5a7e10cb8706cf4ac3c5d88 Mon Sep 17 00:00:00 2001
From: akwizgran <akwizgran@users.sourceforge.net>
Date: Wed, 11 Jan 2012 17:50:24 +0000
Subject: [PATCH] Frame-at-a-time encryption.

---
 .../briar/transport/ConnectionEncrypter.java  |  9 +-
 .../transport/ConnectionEncrypterImpl.java    | 92 +++++--------------
 .../briar/transport/ConnectionWriterImpl.java | 66 +++++++------
 .../ConnectionEncrypterImplTest.java          | 23 ++---
 .../transport/NullConnectionEncrypter.java    | 38 +++-----
 5 files changed, 75 insertions(+), 153 deletions(-)

diff --git a/components/net/sf/briar/transport/ConnectionEncrypter.java b/components/net/sf/briar/transport/ConnectionEncrypter.java
index dd9d1334fd..7d4fdb9425 100644
--- a/components/net/sf/briar/transport/ConnectionEncrypter.java
+++ b/components/net/sf/briar/transport/ConnectionEncrypter.java
@@ -1,16 +1,15 @@
 package net.sf.briar.transport;
 
 import java.io.IOException;
-import java.io.OutputStream;
 
 /** Encrypts authenticated data to be sent over a connection. */
 interface ConnectionEncrypter {
 
-	/** Returns an output stream to which unencrypted data can be written. */
-	OutputStream getOutputStream();
+	/** Encrypts and writes the given frame. */
+	void writeFrame(byte[] b, int off, int len) throws IOException;
 
-	/** Encrypts and writes the remainder of the current frame. */
-	void writeFinal(byte[] b) throws IOException;
+	/** Flushes the output stream. */
+	void flush() throws IOException;
 
 	/** Returns the maximum number of bytes that can be written. */
 	long getRemainingCapacity();
diff --git a/components/net/sf/briar/transport/ConnectionEncrypterImpl.java b/components/net/sf/briar/transport/ConnectionEncrypterImpl.java
index cb0feb08ba..24c66e3515 100644
--- a/components/net/sf/briar/transport/ConnectionEncrypterImpl.java
+++ b/components/net/sf/briar/transport/ConnectionEncrypterImpl.java
@@ -3,7 +3,6 @@ package net.sf.briar.transport;
 import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
 import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
 
-import java.io.FilterOutputStream;
 import java.io.IOException;
 import java.io.OutputStream;
 import java.security.GeneralSecurityException;
@@ -13,19 +12,19 @@ import javax.crypto.spec.IvParameterSpec;
 
 import net.sf.briar.api.crypto.ErasableKey;
 
-class ConnectionEncrypterImpl extends FilterOutputStream
-implements ConnectionEncrypter {
+class ConnectionEncrypterImpl implements ConnectionEncrypter {
 
+	private final OutputStream out;
 	private final Cipher frameCipher;
 	private final ErasableKey frameKey;
 	private final byte[] iv, tag;
 
 	private long capacity, frame = 0L;
-	private boolean tagWritten = false, betweenFrames = false;
+	private boolean tagWritten = false;
 
 	ConnectionEncrypterImpl(OutputStream out, long capacity, Cipher tagCipher,
 			Cipher frameCipher, ErasableKey tagKey, ErasableKey frameKey) {
-		super(out);
+		this.out = out;
 		this.capacity = capacity;
 		this.frameCipher = frameCipher;
 		this.frameKey = frameKey;
@@ -36,84 +35,37 @@ implements ConnectionEncrypter {
 		if(tag.length != TAG_LENGTH) throw new IllegalArgumentException();
 	}
 
-	public OutputStream getOutputStream() {
-		return this;
-	}
-
-	public void writeFinal(byte[] b) throws IOException {
+	public void writeFrame(byte[] b, int off, int len) throws IOException {
 		try {
-			if(!tagWritten || betweenFrames) throw new IllegalStateException();
+			if(!tagWritten) {
+				out.write(tag);
+				capacity -= tag.length;
+				tagWritten = true;
+			}
+			if(frame > MAX_32_BIT_UNSIGNED) throw new IllegalStateException();
+			IvEncoder.updateIv(iv, frame);
+			IvParameterSpec ivSpec = new IvParameterSpec(iv);
 			try {
-				out.write(frameCipher.doFinal(b));
+				frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec);
+				int encrypted = frameCipher.doFinal(b, off, len, b, off);
+				assert encrypted == len;
 			} catch(GeneralSecurityException badCipher) {
 				throw new RuntimeException(badCipher);
 			}
-			capacity -= b.length;
-			betweenFrames = true;
-		} catch(IOException e) {
-			frameKey.erase();
-			throw e;
-		}
-	}
-
-	public long getRemainingCapacity() {
-		return capacity;
-	}
-
-	@Override
-	public void write(int b) throws IOException {
-		try {
-			if(!tagWritten) writeTag();
-			if(betweenFrames) initialiseCipher();
-			byte[] ciphertext = frameCipher.update(new byte[] {(byte) b});
-			if(ciphertext != null) out.write(ciphertext);
-			capacity--;
-		} catch(IOException e) {
-			frameKey.erase();
-			throw e;
-		}
-	}
-
-	@Override
-	public void write(byte[] b) throws IOException {
-		write(b, 0, b.length);
-	}
-
-	@Override
-	public void write(byte[] b, int off, int len) throws IOException {
-		try {
-			if(!tagWritten) writeTag();
-			if(betweenFrames) initialiseCipher();
-			byte[] ciphertext = frameCipher.update(b, off, len);
-			if(ciphertext != null) out.write(ciphertext);
+			out.write(b, off, len);
 			capacity -= len;
+			frame++;
 		} catch(IOException e) {
 			frameKey.erase();
 			throw e;
 		}
 	}
 
-	private void writeTag() throws IOException {
-		assert !tagWritten;
-		assert !betweenFrames;
-		out.write(tag);
-		capacity -= tag.length;
-		tagWritten = true;
-		betweenFrames = true;
+	public void flush() throws IOException {
+		out.flush();
 	}
 
-	private void initialiseCipher() {
-		assert tagWritten;
-		assert betweenFrames;
-		if(frame > MAX_32_BIT_UNSIGNED) throw new IllegalStateException();
-		IvEncoder.updateIv(iv, frame);
-		IvParameterSpec ivSpec = new IvParameterSpec(iv);
-		try {
-			frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec);
-		} catch(GeneralSecurityException badIvOrKey) {
-			throw new RuntimeException(badIvOrKey);
-		}
-		frame++;
-		betweenFrames = false;
+	public long getRemainingCapacity() {
+		return capacity;
 	}
 }
\ No newline at end of file
diff --git a/components/net/sf/briar/transport/ConnectionWriterImpl.java b/components/net/sf/briar/transport/ConnectionWriterImpl.java
index 673a7896be..eb617a2e66 100644
--- a/components/net/sf/briar/transport/ConnectionWriterImpl.java
+++ b/components/net/sf/briar/transport/ConnectionWriterImpl.java
@@ -4,13 +4,12 @@ import static net.sf.briar.api.transport.TransportConstants.FRAME_HEADER_LENGTH;
 import static net.sf.briar.api.transport.TransportConstants.MAX_FRAME_LENGTH;
 import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
 
-import java.io.ByteArrayOutputStream;
-import java.io.FilterOutputStream;
 import java.io.IOException;
 import java.io.OutputStream;
 import java.security.InvalidKeyException;
 
 import javax.crypto.Mac;
+import javax.crypto.ShortBufferException;
 
 import net.sf.briar.api.crypto.ErasableKey;
 import net.sf.briar.api.transport.ConnectionWriter;
@@ -21,20 +20,17 @@ import net.sf.briar.api.transport.ConnectionWriter;
  * <p>
  * This class is not thread-safe.
  */
-class ConnectionWriterImpl extends FilterOutputStream
-implements ConnectionWriter {
+class ConnectionWriterImpl extends OutputStream implements ConnectionWriter {
 
-	protected final ConnectionEncrypter encrypter;
-	protected final Mac mac;
-	protected final int maxPayloadLength;
-	protected final ByteArrayOutputStream buf;
-	protected final byte[] header;
+	private final ConnectionEncrypter encrypter;
+	private final Mac mac;
+	private final byte[] buf;
 
-	protected long frame = 0L;
+	private int bufLength = FRAME_HEADER_LENGTH;
+	private long frame = 0L;
 
 	ConnectionWriterImpl(ConnectionEncrypter encrypter, Mac mac,
 			ErasableKey macKey) {
-		super(encrypter.getOutputStream());
 		this.encrypter = encrypter;
 		this.mac = mac;
 		// Initialise the MAC
@@ -44,10 +40,7 @@ implements ConnectionWriter {
 			throw new IllegalArgumentException(badKey);
 		}
 		macKey.erase();
-		maxPayloadLength =
-			MAX_FRAME_LENGTH - FRAME_HEADER_LENGTH - mac.getMacLength();
-		buf = new ByteArrayOutputStream(maxPayloadLength);
-		header = new byte[FRAME_HEADER_LENGTH];
+		buf = new byte[MAX_FRAME_LENGTH];
 	}
 
 	public OutputStream getOutputStream() {
@@ -57,23 +50,24 @@ implements ConnectionWriter {
 	public long getRemainingCapacity() {
 		long capacity = encrypter.getRemainingCapacity();
 		// If there's any data buffered, subtract it and its auth overhead
-		int overheadPerFrame = header.length + mac.getMacLength();
-		if(buf.size() > 0) capacity -= buf.size() + overheadPerFrame;
+		if(bufLength > FRAME_HEADER_LENGTH)
+			capacity -= bufLength + mac.getMacLength();
 		// Subtract the auth overhead from the remaining capacity
 		long frames = (long) Math.ceil((double) capacity / MAX_FRAME_LENGTH);
+		int overheadPerFrame = FRAME_HEADER_LENGTH + mac.getMacLength();
 		return Math.max(0L, capacity - frames * overheadPerFrame);
 	}
 
 	@Override
 	public void flush() throws IOException {
-		if(buf.size() > 0) writeFrame();
-		out.flush();
+		if(bufLength > FRAME_HEADER_LENGTH) writeFrame();
+		encrypter.flush();
 	}
 
 	@Override
 	public void write(int b) throws IOException {
-		buf.write(b);
-		if(buf.size() == maxPayloadLength) writeFrame();
+		buf[bufLength++] = (byte) b;
+		if(bufLength + mac.getMacLength() == MAX_FRAME_LENGTH) writeFrame();
 	}
 
 	@Override
@@ -83,28 +77,32 @@ implements ConnectionWriter {
 
 	@Override
 	public void write(byte[] b, int off, int len) throws IOException {
-		int available = maxPayloadLength - buf.size();
+		int available = MAX_FRAME_LENGTH - bufLength - mac.getMacLength();
 		while(available <= len) {
-			buf.write(b, off, available);
+			System.arraycopy(b, off, buf, bufLength, available);
+			bufLength += available;
 			writeFrame();
 			off += available;
 			len -= available;
-			available = maxPayloadLength;
+			available = MAX_FRAME_LENGTH - bufLength - mac.getMacLength();
 		}
-		buf.write(b, off, len);
+		System.arraycopy(b, off, buf, bufLength, len);
+		bufLength += len;
 	}
 
 	private void writeFrame() throws IOException {
 		if(frame > MAX_32_BIT_UNSIGNED) throw new IllegalStateException();
-		byte[] payload = buf.toByteArray();
-		if(payload.length > maxPayloadLength) throw new IllegalStateException();
-		HeaderEncoder.encodeHeader(header, frame, payload.length, 0);
-		out.write(header);
-		mac.update(header);
-		out.write(payload);
-		mac.update(payload);
-		encrypter.writeFinal(mac.doFinal());
+		int payloadLength = bufLength - FRAME_HEADER_LENGTH;
+		assert payloadLength > 0;
+		HeaderEncoder.encodeHeader(buf, frame, payloadLength, 0);
+		mac.update(buf, 0, bufLength);
+		try {
+			mac.doFinal(buf, bufLength);
+		} catch(ShortBufferException badMac) {
+			throw new RuntimeException(badMac);
+		}
+		encrypter.writeFrame(buf, 0, bufLength + mac.getMacLength());
+		bufLength = FRAME_HEADER_LENGTH;
 		frame++;
-		buf.reset();
 	}
 }
diff --git a/test/net/sf/briar/transport/ConnectionEncrypterImplTest.java b/test/net/sf/briar/transport/ConnectionEncrypterImplTest.java
index 469c333cbd..7bb2996cfa 100644
--- a/test/net/sf/briar/transport/ConnectionEncrypterImplTest.java
+++ b/test/net/sf/briar/transport/ConnectionEncrypterImplTest.java
@@ -49,25 +49,16 @@ public class ConnectionEncrypterImplTest extends BriarTestCase {
 		byte[] tag = TagEncoder.encodeTag(0, tagCipher, tagKey);
 		// Calculate the expected ciphertext for the first frame
 		byte[] iv = new byte[frameCipher.getBlockSize()];
-		byte[] plaintext = new byte[123];
-		byte[] plaintextMac = new byte[MAC_LENGTH];
+		byte[] plaintext = new byte[123 + MAC_LENGTH];
 		IvParameterSpec ivSpec = new IvParameterSpec(iv);
 		frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec);
-		byte[] ciphertext = new byte[plaintext.length + plaintextMac.length];
-		int offset = frameCipher.update(plaintext, 0, plaintext.length,
-				ciphertext);
-		frameCipher.doFinal(plaintextMac, 0, plaintextMac.length, ciphertext,
-				offset);
+		byte[] ciphertext = frameCipher.doFinal(plaintext);
 		// Calculate the expected ciphertext for the second frame
-		byte[] plaintext1 = new byte[1234];
+		byte[] plaintext1 = new byte[1234 + MAC_LENGTH];
 		IvEncoder.updateIv(iv, 1L);
 		ivSpec = new IvParameterSpec(iv);
 		frameCipher.init(Cipher.ENCRYPT_MODE, frameKey, ivSpec);
-		byte[] ciphertext1 = new byte[plaintext1.length + plaintextMac.length];
-		offset = frameCipher.update(plaintext1, 0, plaintext1.length,
-				ciphertext1);
-		frameCipher.doFinal(plaintextMac, 0, plaintextMac.length, ciphertext1,
-				offset);
+		byte[] ciphertext1 = frameCipher.doFinal(plaintext1);
 		// Concatenate the ciphertexts
 		ByteArrayOutputStream out = new ByteArrayOutputStream();
 		out.write(tag);
@@ -78,10 +69,8 @@ public class ConnectionEncrypterImplTest extends BriarTestCase {
 		out.reset();
 		ConnectionEncrypter e = new ConnectionEncrypterImpl(out, Long.MAX_VALUE,
 				tagCipher, frameCipher, tagKey, frameKey);
-		e.getOutputStream().write(plaintext);
-		e.writeFinal(plaintextMac);
-		e.getOutputStream().write(plaintext1);
-		e.writeFinal(plaintextMac);
+		e.writeFrame(plaintext, 0, plaintext.length);
+		e.writeFrame(plaintext1, 0, plaintext1.length);
 		byte[] actual = out.toByteArray();
 		// Check that the actual ciphertext matches the expected ciphertext
 		assertArrayEquals(expected, actual);
diff --git a/test/net/sf/briar/transport/NullConnectionEncrypter.java b/test/net/sf/briar/transport/NullConnectionEncrypter.java
index 29ebcdd0e8..afcbb47765 100644
--- a/test/net/sf/briar/transport/NullConnectionEncrypter.java
+++ b/test/net/sf/briar/transport/NullConnectionEncrypter.java
@@ -1,51 +1,35 @@
 package net.sf.briar.transport;
 
-import java.io.FilterOutputStream;
 import java.io.IOException;
 import java.io.OutputStream;
 
 /** A ConnectionEncrypter that performs no encryption. */
-class NullConnectionEncrypter extends FilterOutputStream
-implements ConnectionEncrypter {
+class NullConnectionEncrypter implements ConnectionEncrypter {
+
+	private final OutputStream out;
 
 	private long capacity;
 
 	NullConnectionEncrypter(OutputStream out) {
-		this(out, Long.MAX_VALUE);
+		this.out = out;
+		capacity = Long.MAX_VALUE;
 	}
 
 	NullConnectionEncrypter(OutputStream out, long capacity) {
-		super(out);
+		this.out = out;
 		this.capacity = capacity;
 	}
 
-	public OutputStream getOutputStream() {
-		return this;
+	public void writeFrame(byte[] b, int off, int len) throws IOException {
+		out.write(b, off, len);
+		capacity -= len;
 	}
 
-	public void writeFinal(byte[] mac) throws IOException {
-		out.write(mac);
-		capacity -= mac.length;
+	public void flush() throws IOException {
+		out.flush();
 	}
 
 	public long getRemainingCapacity() {
 		return capacity;
 	}
-
-	@Override
-	public void write(int b) throws IOException {
-		out.write(b);
-		capacity--;
-	}
-
-	@Override
-	public void write(byte[] b) throws IOException {
-		write(b, 0, b.length);
-	}
-
-	@Override
-	public void write(byte[] b, int off, int len) throws IOException {
-		out.write(b, off, len);
-		capacity -= len;
-	}
 }
-- 
GitLab