From a8b96f11fd0fc37f46783c566597e27dc6459c7b Mon Sep 17 00:00:00 2001
From: akwizgran <akwizgran@users.sourceforge.net>
Date: Wed, 28 Sep 2011 18:47:24 +0100
Subject: [PATCH] Added Consumer support to Writer, to avoid redundant copying.

---
 api/net/sf/briar/api/serial/Consumer.java     |   2 -
 api/net/sf/briar/api/serial/Writer.java       |   3 +
 .../sf/briar/protocol/CopyingConsumer.java    |   4 -
 .../sf/briar/protocol/CountingConsumer.java   |   5 -
 .../sf/briar/protocol/DigestingConsumer.java  |   4 -
 .../sf/briar/protocol/MessageEncoderImpl.java |  36 ++++--
 .../sf/briar/protocol/SigningConsumer.java    |  33 ++++++
 .../net/sf/briar/serial/ReaderImpl.java       |  19 +---
 .../net/sf/briar/serial/WriterImpl.java       | 105 +++++++++++-------
 .../net/sf/briar/transport/MacConsumer.java   |   4 -
 test/net/sf/briar/FileReadWriteTest.java      |  12 +-
 test/net/sf/briar/serial/ReaderImplTest.java  |   4 -
 12 files changed, 134 insertions(+), 97 deletions(-)
 create mode 100644 components/net/sf/briar/protocol/SigningConsumer.java

diff --git a/api/net/sf/briar/api/serial/Consumer.java b/api/net/sf/briar/api/serial/Consumer.java
index 4091648d0f..4d08ceb382 100644
--- a/api/net/sf/briar/api/serial/Consumer.java
+++ b/api/net/sf/briar/api/serial/Consumer.java
@@ -6,7 +6,5 @@ public interface Consumer {
 
 	void write(byte b) throws IOException;
 
-	void write(byte[] b) throws IOException;
-
 	void write(byte[] b, int off, int len) throws IOException;
 }
diff --git a/api/net/sf/briar/api/serial/Writer.java b/api/net/sf/briar/api/serial/Writer.java
index 887c1e51b3..9b2b1482d3 100644
--- a/api/net/sf/briar/api/serial/Writer.java
+++ b/api/net/sf/briar/api/serial/Writer.java
@@ -6,6 +6,9 @@ import java.util.Map;
 
 public interface Writer {
 
+	void addConsumer(Consumer c);
+	void removeConsumer(Consumer c);
+
 	void writeBoolean(boolean b) throws IOException;
 
 	void writeUint7(byte b) throws IOException;
diff --git a/components/net/sf/briar/protocol/CopyingConsumer.java b/components/net/sf/briar/protocol/CopyingConsumer.java
index 6e84d5929a..97129ab327 100644
--- a/components/net/sf/briar/protocol/CopyingConsumer.java
+++ b/components/net/sf/briar/protocol/CopyingConsumer.java
@@ -18,10 +18,6 @@ class CopyingConsumer implements Consumer {
 		out.write(b);
 	}
 
-	public void write(byte[] b) throws IOException {
-		out.write(b);
-	}
-
 	public void write(byte[] b, int off, int len) throws IOException {
 		out.write(b, off, len);
 	}
diff --git a/components/net/sf/briar/protocol/CountingConsumer.java b/components/net/sf/briar/protocol/CountingConsumer.java
index bb032341e0..7722ce5437 100644
--- a/components/net/sf/briar/protocol/CountingConsumer.java
+++ b/components/net/sf/briar/protocol/CountingConsumer.java
@@ -27,11 +27,6 @@ class CountingConsumer implements Consumer {
 		if(count > limit) throw new FormatException();
 	}
 
-	public void write(byte[] b) throws IOException {
-		count += b.length;
-		if(count > limit) throw new FormatException();
-	}
-
 	public void write(byte[] b, int off, int len) throws IOException {
 		count += len;
 		if(count > limit) throw new FormatException();
diff --git a/components/net/sf/briar/protocol/DigestingConsumer.java b/components/net/sf/briar/protocol/DigestingConsumer.java
index 78826ff45b..29a0c260ef 100644
--- a/components/net/sf/briar/protocol/DigestingConsumer.java
+++ b/components/net/sf/briar/protocol/DigestingConsumer.java
@@ -17,10 +17,6 @@ class DigestingConsumer implements Consumer {
 		messageDigest.update(b);
 	}
 
-	public void write(byte[] b) {
-		messageDigest.update(b);
-	}
-
 	public void write(byte[] b, int off, int len) {
 		messageDigest.update(b, off, len);
 	}
diff --git a/components/net/sf/briar/protocol/MessageEncoderImpl.java b/components/net/sf/briar/protocol/MessageEncoderImpl.java
index 833bc3e418..acf1aeebf0 100644
--- a/components/net/sf/briar/protocol/MessageEncoderImpl.java
+++ b/components/net/sf/briar/protocol/MessageEncoderImpl.java
@@ -17,6 +17,7 @@ import net.sf.briar.api.protocol.Message;
 import net.sf.briar.api.protocol.MessageEncoder;
 import net.sf.briar.api.protocol.MessageId;
 import net.sf.briar.api.protocol.Types;
+import net.sf.briar.api.serial.Consumer;
 import net.sf.briar.api.serial.Writer;
 import net.sf.briar.api.serial.WriterFactory;
 
@@ -24,14 +25,15 @@ import com.google.inject.Inject;
 
 class MessageEncoderImpl implements MessageEncoder {
 
-	private final Signature signature;
+	private final Signature authorSignature, groupSignature;
 	private final SecureRandom random;
 	private final MessageDigest messageDigest;
 	private final WriterFactory writerFactory;
 
 	@Inject
 	MessageEncoderImpl(CryptoComponent crypto, WriterFactory writerFactory) {
-		signature = crypto.getSignature();
+		authorSignature = crypto.getSignature();
+		groupSignature = crypto.getSignature();
 		random = crypto.getSecureRandom();
 		messageDigest = crypto.getMessageDigest();
 		this.writerFactory = writerFactory;
@@ -71,9 +73,23 @@ class MessageEncoderImpl implements MessageEncoder {
 		if(body.length > Message.MAX_BODY_LENGTH)
 			throw new IllegalArgumentException();
 
-		long timestamp = System.currentTimeMillis();
 		ByteArrayOutputStream out = new ByteArrayOutputStream();
 		Writer w = writerFactory.createWriter(out);
+		// Initialise the consumers
+		Consumer digestingConsumer = new DigestingConsumer(messageDigest);
+		w.addConsumer(digestingConsumer);
+		Consumer authorConsumer = null;
+		if(authorKey != null) {
+			authorSignature.initSign(authorKey);
+			authorConsumer = new SigningConsumer(authorSignature);
+			w.addConsumer(authorConsumer);
+		}
+		Consumer groupConsumer = null;
+		if(groupKey != null) {
+			groupSignature.initSign(groupKey);
+			groupConsumer = new SigningConsumer(groupSignature);
+			w.addConsumer(groupConsumer);
+		}
 		// Write the message
 		w.writeUserDefinedId(Types.MESSAGE);
 		if(parent == null) w.writeNull();
@@ -82,6 +98,7 @@ class MessageEncoderImpl implements MessageEncoder {
 		else group.writeTo(w);
 		if(author == null) w.writeNull();
 		else author.writeTo(w);
+		long timestamp = System.currentTimeMillis();
 		w.writeInt64(timestamp);
 		byte[] salt = new byte[Message.SALT_LENGTH];
 		random.nextBytes(salt);
@@ -91,9 +108,8 @@ class MessageEncoderImpl implements MessageEncoder {
 		if(authorKey == null) {
 			w.writeNull();
 		} else {
-			signature.initSign(authorKey);
-			signature.update(out.toByteArray());
-			byte[] sig = signature.sign();
+			w.removeConsumer(authorConsumer);
+			byte[] sig = authorSignature.sign();
 			if(sig.length > Message.MAX_SIGNATURE_LENGTH)
 				throw new IllegalArgumentException();
 			w.writeBytes(sig);
@@ -102,17 +118,15 @@ class MessageEncoderImpl implements MessageEncoder {
 		if(groupKey == null) {
 			w.writeNull();
 		} else {
-			signature.initSign(groupKey);
-			signature.update(out.toByteArray());
-			byte[] sig = signature.sign();
+			w.removeConsumer(groupConsumer);
+			byte[] sig = groupSignature.sign();
 			if(sig.length > Message.MAX_SIGNATURE_LENGTH)
 				throw new IllegalArgumentException();
 			w.writeBytes(sig);
 		}
 		// Hash the message, including the signatures, to get the message ID
+		w.removeConsumer(digestingConsumer);
 		byte[] raw = out.toByteArray();
-		messageDigest.reset();
-		messageDigest.update(raw);
 		MessageId id = new MessageId(messageDigest.digest());
 		GroupId groupId = group == null ? null : group.getId();
 		AuthorId authorId = author == null ? null : author.getId();
diff --git a/components/net/sf/briar/protocol/SigningConsumer.java b/components/net/sf/briar/protocol/SigningConsumer.java
new file mode 100644
index 0000000000..c3274ba7ff
--- /dev/null
+++ b/components/net/sf/briar/protocol/SigningConsumer.java
@@ -0,0 +1,33 @@
+package net.sf.briar.protocol;
+
+import java.io.IOException;
+import java.security.Signature;
+import java.security.SignatureException;
+
+import net.sf.briar.api.serial.Consumer;
+
+/** A consumer that passes its input through a signature. */
+class SigningConsumer implements Consumer {
+
+	private final Signature signature;
+
+	SigningConsumer(Signature signature) {
+		this.signature = signature;
+	}
+
+	public void write(byte b) throws IOException {
+		try {
+			signature.update(b);
+		} catch(SignatureException e) {
+			throw new IOException(e.getMessage());
+		}
+	}
+
+	public void write(byte[] b, int off, int len) throws IOException {
+		try {
+			signature.update(b, off, len);
+		} catch(SignatureException e) {
+			throw new IOException(e.getMessage());
+		}
+	}
+}
diff --git a/components/net/sf/briar/serial/ReaderImpl.java b/components/net/sf/briar/serial/ReaderImpl.java
index 600e0abdf0..acbf2e4922 100644
--- a/components/net/sf/briar/serial/ReaderImpl.java
+++ b/components/net/sf/briar/serial/ReaderImpl.java
@@ -19,8 +19,8 @@ class ReaderImpl implements Reader {
 	private static final byte[] EMPTY_BUFFER = new byte[] {};
 
 	private final InputStream in;
+	private final List<Consumer> consumers = new ArrayList<Consumer>(0);
 
-	private Consumer[] consumers = new Consumer[] {};
 	private ObjectReader<?>[] objectReaders = new ObjectReader<?>[] {};
 	private boolean hasLookahead = false, eof = false;
 	private byte next, nextNext;
@@ -89,24 +89,11 @@ class ReaderImpl implements Reader {
 	}
 
 	public void addConsumer(Consumer c) {
-		Consumer[] newConsumers = new Consumer[consumers.length + 1];
-		System.arraycopy(consumers, 0, newConsumers, 0, consumers.length);
-		newConsumers[consumers.length] = c;
-		consumers = newConsumers;
+		consumers.add(c);
 	}
 
 	public void removeConsumer(Consumer c) {
-		if(consumers.length == 0) throw new IllegalArgumentException();
-		Consumer[] newConsumers = new Consumer[consumers.length - 1];
-		boolean found = false;
-		for(int src = 0, dest = 0; src < consumers.length; src++, dest++) {
-			if(!found && consumers[src].equals(c)) {
-				found = true;
-				src++;
-			} else newConsumers[dest] = consumers[src];
-		}
-		if(found) consumers = newConsumers;
-		else throw new IllegalArgumentException();
+		if(!consumers.remove(c)) throw new IllegalArgumentException();
 	}
 
 	public void addObjectReader(int id, ObjectReader<?> o) {
diff --git a/components/net/sf/briar/serial/WriterImpl.java b/components/net/sf/briar/serial/WriterImpl.java
index ac786c57fd..d491a6a5ae 100644
--- a/components/net/sf/briar/serial/WriterImpl.java
+++ b/components/net/sf/briar/serial/WriterImpl.java
@@ -2,70 +2,81 @@ package net.sf.briar.serial;
 
 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
 
 import net.sf.briar.api.Bytes;
+import net.sf.briar.api.serial.Consumer;
 import net.sf.briar.api.serial.Writable;
 import net.sf.briar.api.serial.Writer;
 
 class WriterImpl implements Writer {
 
 	private final OutputStream out;
+	private final List<Consumer> consumers = new ArrayList<Consumer>(0);
 
 	WriterImpl(OutputStream out) {
 		this.out = out;
 	}
 
+	public void addConsumer(Consumer c) {
+		consumers.add(c);
+	}
+
+	public void removeConsumer(Consumer c) {
+		if(!consumers.remove(c)) throw new IllegalArgumentException();
+	}
+
 	public void writeBoolean(boolean b) throws IOException {
-		if(b) out.write(Tag.TRUE);
-		else out.write(Tag.FALSE);
+		if(b) write(Tag.TRUE);
+		else write(Tag.FALSE);
 	}
 
 	public void writeUint7(byte b) throws IOException {
 		if(b < 0) throw new IllegalArgumentException();
-		out.write(b);
+		write(b);
 	}
 
 	public void writeInt8(byte b) throws IOException {
-		out.write(Tag.INT8);
-		out.write(b);
+		write(Tag.INT8);
+		write(b);
 	}
 
 	public void writeInt16(short s) throws IOException {
-		out.write(Tag.INT16);
-		out.write((byte) (s >> 8));
-		out.write((byte) ((s << 8) >> 8));
+		write(Tag.INT16);
+		write((byte) (s >> 8));
+		write((byte) ((s << 8) >> 8));
 	}
 
 	public void writeInt32(int i) throws IOException {
-		out.write(Tag.INT32);
+		write(Tag.INT32);
 		writeInt32Bits(i);
 	}
 
 	private void writeInt32Bits(int i) throws IOException {
-		out.write((byte) (i >> 24));
-		out.write((byte) ((i << 8) >> 24));
-		out.write((byte) ((i << 16) >> 24));
-		out.write((byte) ((i << 24) >> 24));
+		write((byte) (i >> 24));
+		write((byte) ((i << 8) >> 24));
+		write((byte) ((i << 16) >> 24));
+		write((byte) ((i << 24) >> 24));
 	}
 
 	public void writeInt64(long l) throws IOException {
-		out.write(Tag.INT64);
+		write(Tag.INT64);
 		writeInt64Bits(l);
 	}
 
 	private void writeInt64Bits(long l) throws IOException {
-		out.write((byte) (l >> 56));
-		out.write((byte) ((l << 8) >> 56));
-		out.write((byte) ((l << 16) >> 56));
-		out.write((byte) ((l << 24) >> 56));
-		out.write((byte) ((l << 32) >> 56));
-		out.write((byte) ((l << 40) >> 56));
-		out.write((byte) ((l << 48) >> 56));
-		out.write((byte) ((l << 56) >> 56));
+		write((byte) (l >> 56));
+		write((byte) ((l << 8) >> 56));
+		write((byte) ((l << 16) >> 56));
+		write((byte) ((l << 24) >> 56));
+		write((byte) ((l << 32) >> 56));
+		write((byte) ((l << 40) >> 56));
+		write((byte) ((l << 48) >> 56));
+		write((byte) ((l << 56) >> 56));
 	}
 
 	public void writeIntAny(long l) throws IOException {
@@ -81,23 +92,23 @@ class WriterImpl implements Writer {
 	}
 
 	public void writeFloat32(float f) throws IOException {
-		out.write(Tag.FLOAT32);
+		write(Tag.FLOAT32);
 		writeInt32Bits(Float.floatToRawIntBits(f));
 	}
 
 	public void writeFloat64(double d) throws IOException {
-		out.write(Tag.FLOAT64);
+		write(Tag.FLOAT64);
 		writeInt64Bits(Double.doubleToRawLongBits(d));
 	}
 
 	public void writeString(String s) throws IOException {
 		byte[] b = s.getBytes("UTF-8");
-		if(b.length < 16) out.write((byte) (Tag.SHORT_STRING | b.length));
+		if(b.length < 16) write((byte) (Tag.SHORT_STRING | b.length));
 		else {
-			out.write(Tag.STRING);
+			write(Tag.STRING);
 			writeLength(b.length);
 		}
-		out.write(b);
+		write(b);
 	}
 
 	private void writeLength(int i) throws IOException {
@@ -109,19 +120,19 @@ class WriterImpl implements Writer {
 	}
 
 	public void writeBytes(byte[] b) throws IOException {
-		if(b.length < 16) out.write((byte) (Tag.SHORT_BYTES | b.length));
+		if(b.length < 16) write((byte) (Tag.SHORT_BYTES | b.length));
 		else {
-			out.write(Tag.BYTES);
+			write(Tag.BYTES);
 			writeLength(b.length);
 		}
-		out.write(b);
+		write(b);
 	}
 
 	public void writeList(Collection<?> c) throws IOException {
 		int length = c.size();
-		if(length < 16) out.write((byte) (Tag.SHORT_LIST | length));
+		if(length < 16) write((byte) (Tag.SHORT_LIST | length));
 		else {
-			out.write(Tag.LIST);
+			write(Tag.LIST);
 			writeLength(length);
 		}
 		for(Object o : c) writeObject(o);
@@ -145,18 +156,18 @@ class WriterImpl implements Writer {
 	}
 
 	public void writeListStart() throws IOException {
-		out.write(Tag.LIST_START);
+		write(Tag.LIST_START);
 	}
 
 	public void writeListEnd() throws IOException {
-		out.write(Tag.END);
+		write(Tag.END);
 	}
 
 	public void writeMap(Map<?, ?> m) throws IOException {
 		int length = m.size();
-		if(length < 16) out.write((byte) (Tag.SHORT_MAP | length));
+		if(length < 16) write((byte) (Tag.SHORT_MAP | length));
 		else {
-			out.write(Tag.MAP);
+			write(Tag.MAP);
 			writeLength(length);
 		}
 		for(Entry<?, ?> e : m.entrySet()) {
@@ -166,24 +177,34 @@ class WriterImpl implements Writer {
 	}
 
 	public void writeMapStart() throws IOException {
-		out.write(Tag.MAP_START);
+		write(Tag.MAP_START);
 	}
 
 	public void writeMapEnd() throws IOException {
-		out.write(Tag.END);
+		write(Tag.END);
 	}
 
 	public void writeNull() throws IOException {
-		out.write(Tag.NULL);
+		write(Tag.NULL);
 	}
 
 	public void writeUserDefinedId(int id) throws IOException {
 		if(id < 0 || id > 255) throw new IllegalArgumentException();
 		if(id < 32) {
-			out.write((byte) (Tag.SHORT_USER | id));
+			write((byte) (Tag.SHORT_USER | id));
 		} else {
-			out.write(Tag.USER);
-			out.write((byte) id);
+			write(Tag.USER);
+			write((byte) id);
 		}
 	}
+
+	private void write(byte b) throws IOException {
+		out.write(b);
+		for(Consumer c : consumers) c.write(b);
+	}
+
+	private void write(byte[] b) throws IOException {
+		out.write(b);
+		for(Consumer c : consumers) c.write(b, 0, b.length);
+	}
 }
diff --git a/components/net/sf/briar/transport/MacConsumer.java b/components/net/sf/briar/transport/MacConsumer.java
index dbf20af8d1..f494cf2a55 100644
--- a/components/net/sf/briar/transport/MacConsumer.java
+++ b/components/net/sf/briar/transport/MacConsumer.java
@@ -17,10 +17,6 @@ class MacConsumer implements Consumer {
 		mac.update(b);
 	}
 
-	public void write(byte[] b) {
-		mac.update(b);
-	}
-
 	public void write(byte[] b, int off, int len) {
 		mac.update(b, off, len);
 	}
diff --git a/test/net/sf/briar/FileReadWriteTest.java b/test/net/sf/briar/FileReadWriteTest.java
index a624fb2bcb..12c5ebb935 100644
--- a/test/net/sf/briar/FileReadWriteTest.java
+++ b/test/net/sf/briar/FileReadWriteTest.java
@@ -130,7 +130,12 @@ public class FileReadWriteTest extends TestCase {
 	}
 
 	@Test
-	public void testWriteFile() throws Exception {
+	public void testWriteAndRead() throws Exception {
+		write();
+		read();
+	}
+
+	private void write() throws Exception {
 		OutputStream out = new FileOutputStream(file);
 		// Use Alice's secret for writing
 		ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out,
@@ -177,10 +182,7 @@ public class FileReadWriteTest extends TestCase {
 		assertTrue(file.length() > message.getSize());
 	}
 
-	@Test
-	public void testWriteAndReadFile() throws Exception {
-
-		testWriteFile();
+	private void read() throws Exception {
 
 		InputStream in = new FileInputStream(file);
 		byte[] iv = new byte[16];
diff --git a/test/net/sf/briar/serial/ReaderImplTest.java b/test/net/sf/briar/serial/ReaderImplTest.java
index 982f0f397a..539d0da49c 100644
--- a/test/net/sf/briar/serial/ReaderImplTest.java
+++ b/test/net/sf/briar/serial/ReaderImplTest.java
@@ -416,10 +416,6 @@ public class ReaderImplTest extends TestCase {
 				out.write(b);
 			}
 
-			public void write(byte[] b) throws IOException {
-				out.write(b);
-			}
-
 			public void write(byte[] b, int off, int len) throws IOException {
 				out.write(b, off, len);
 			}
-- 
GitLab