From a8b96f11fd0fc37f46783c566597e27dc6459c7b Mon Sep 17 00:00:00 2001 From: akwizgran <akwizgran@users.sourceforge.net> Date: Wed, 28 Sep 2011 18:47:24 +0100 Subject: [PATCH] Added Consumer support to Writer, to avoid redundant copying. --- api/net/sf/briar/api/serial/Consumer.java | 2 - api/net/sf/briar/api/serial/Writer.java | 3 + .../sf/briar/protocol/CopyingConsumer.java | 4 - .../sf/briar/protocol/CountingConsumer.java | 5 - .../sf/briar/protocol/DigestingConsumer.java | 4 - .../sf/briar/protocol/MessageEncoderImpl.java | 36 ++++-- .../sf/briar/protocol/SigningConsumer.java | 33 ++++++ .../net/sf/briar/serial/ReaderImpl.java | 19 +--- .../net/sf/briar/serial/WriterImpl.java | 105 +++++++++++------- .../net/sf/briar/transport/MacConsumer.java | 4 - test/net/sf/briar/FileReadWriteTest.java | 12 +- test/net/sf/briar/serial/ReaderImplTest.java | 4 - 12 files changed, 134 insertions(+), 97 deletions(-) create mode 100644 components/net/sf/briar/protocol/SigningConsumer.java diff --git a/api/net/sf/briar/api/serial/Consumer.java b/api/net/sf/briar/api/serial/Consumer.java index 4091648d0f..4d08ceb382 100644 --- a/api/net/sf/briar/api/serial/Consumer.java +++ b/api/net/sf/briar/api/serial/Consumer.java @@ -6,7 +6,5 @@ public interface Consumer { void write(byte b) throws IOException; - void write(byte[] b) throws IOException; - void write(byte[] b, int off, int len) throws IOException; } diff --git a/api/net/sf/briar/api/serial/Writer.java b/api/net/sf/briar/api/serial/Writer.java index 887c1e51b3..9b2b1482d3 100644 --- a/api/net/sf/briar/api/serial/Writer.java +++ b/api/net/sf/briar/api/serial/Writer.java @@ -6,6 +6,9 @@ import java.util.Map; public interface Writer { + void addConsumer(Consumer c); + void removeConsumer(Consumer c); + void writeBoolean(boolean b) throws IOException; void writeUint7(byte b) throws IOException; diff --git a/components/net/sf/briar/protocol/CopyingConsumer.java b/components/net/sf/briar/protocol/CopyingConsumer.java index 6e84d5929a..97129ab327 100644 --- a/components/net/sf/briar/protocol/CopyingConsumer.java +++ b/components/net/sf/briar/protocol/CopyingConsumer.java @@ -18,10 +18,6 @@ class CopyingConsumer implements Consumer { out.write(b); } - public void write(byte[] b) throws IOException { - out.write(b); - } - public void write(byte[] b, int off, int len) throws IOException { out.write(b, off, len); } diff --git a/components/net/sf/briar/protocol/CountingConsumer.java b/components/net/sf/briar/protocol/CountingConsumer.java index bb032341e0..7722ce5437 100644 --- a/components/net/sf/briar/protocol/CountingConsumer.java +++ b/components/net/sf/briar/protocol/CountingConsumer.java @@ -27,11 +27,6 @@ class CountingConsumer implements Consumer { if(count > limit) throw new FormatException(); } - public void write(byte[] b) throws IOException { - count += b.length; - if(count > limit) throw new FormatException(); - } - public void write(byte[] b, int off, int len) throws IOException { count += len; if(count > limit) throw new FormatException(); diff --git a/components/net/sf/briar/protocol/DigestingConsumer.java b/components/net/sf/briar/protocol/DigestingConsumer.java index 78826ff45b..29a0c260ef 100644 --- a/components/net/sf/briar/protocol/DigestingConsumer.java +++ b/components/net/sf/briar/protocol/DigestingConsumer.java @@ -17,10 +17,6 @@ class DigestingConsumer implements Consumer { messageDigest.update(b); } - public void write(byte[] b) { - messageDigest.update(b); - } - public void write(byte[] b, int off, int len) { messageDigest.update(b, off, len); } diff --git a/components/net/sf/briar/protocol/MessageEncoderImpl.java b/components/net/sf/briar/protocol/MessageEncoderImpl.java index 833bc3e418..acf1aeebf0 100644 --- a/components/net/sf/briar/protocol/MessageEncoderImpl.java +++ b/components/net/sf/briar/protocol/MessageEncoderImpl.java @@ -17,6 +17,7 @@ import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.MessageEncoder; import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.Types; +import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Writer; import net.sf.briar.api.serial.WriterFactory; @@ -24,14 +25,15 @@ import com.google.inject.Inject; class MessageEncoderImpl implements MessageEncoder { - private final Signature signature; + private final Signature authorSignature, groupSignature; private final SecureRandom random; private final MessageDigest messageDigest; private final WriterFactory writerFactory; @Inject MessageEncoderImpl(CryptoComponent crypto, WriterFactory writerFactory) { - signature = crypto.getSignature(); + authorSignature = crypto.getSignature(); + groupSignature = crypto.getSignature(); random = crypto.getSecureRandom(); messageDigest = crypto.getMessageDigest(); this.writerFactory = writerFactory; @@ -71,9 +73,23 @@ class MessageEncoderImpl implements MessageEncoder { if(body.length > Message.MAX_BODY_LENGTH) throw new IllegalArgumentException(); - long timestamp = System.currentTimeMillis(); ByteArrayOutputStream out = new ByteArrayOutputStream(); Writer w = writerFactory.createWriter(out); + // Initialise the consumers + Consumer digestingConsumer = new DigestingConsumer(messageDigest); + w.addConsumer(digestingConsumer); + Consumer authorConsumer = null; + if(authorKey != null) { + authorSignature.initSign(authorKey); + authorConsumer = new SigningConsumer(authorSignature); + w.addConsumer(authorConsumer); + } + Consumer groupConsumer = null; + if(groupKey != null) { + groupSignature.initSign(groupKey); + groupConsumer = new SigningConsumer(groupSignature); + w.addConsumer(groupConsumer); + } // Write the message w.writeUserDefinedId(Types.MESSAGE); if(parent == null) w.writeNull(); @@ -82,6 +98,7 @@ class MessageEncoderImpl implements MessageEncoder { else group.writeTo(w); if(author == null) w.writeNull(); else author.writeTo(w); + long timestamp = System.currentTimeMillis(); w.writeInt64(timestamp); byte[] salt = new byte[Message.SALT_LENGTH]; random.nextBytes(salt); @@ -91,9 +108,8 @@ class MessageEncoderImpl implements MessageEncoder { if(authorKey == null) { w.writeNull(); } else { - signature.initSign(authorKey); - signature.update(out.toByteArray()); - byte[] sig = signature.sign(); + w.removeConsumer(authorConsumer); + byte[] sig = authorSignature.sign(); if(sig.length > Message.MAX_SIGNATURE_LENGTH) throw new IllegalArgumentException(); w.writeBytes(sig); @@ -102,17 +118,15 @@ class MessageEncoderImpl implements MessageEncoder { if(groupKey == null) { w.writeNull(); } else { - signature.initSign(groupKey); - signature.update(out.toByteArray()); - byte[] sig = signature.sign(); + w.removeConsumer(groupConsumer); + byte[] sig = groupSignature.sign(); if(sig.length > Message.MAX_SIGNATURE_LENGTH) throw new IllegalArgumentException(); w.writeBytes(sig); } // Hash the message, including the signatures, to get the message ID + w.removeConsumer(digestingConsumer); byte[] raw = out.toByteArray(); - messageDigest.reset(); - messageDigest.update(raw); MessageId id = new MessageId(messageDigest.digest()); GroupId groupId = group == null ? null : group.getId(); AuthorId authorId = author == null ? null : author.getId(); diff --git a/components/net/sf/briar/protocol/SigningConsumer.java b/components/net/sf/briar/protocol/SigningConsumer.java new file mode 100644 index 0000000000..c3274ba7ff --- /dev/null +++ b/components/net/sf/briar/protocol/SigningConsumer.java @@ -0,0 +1,33 @@ +package net.sf.briar.protocol; + +import java.io.IOException; +import java.security.Signature; +import java.security.SignatureException; + +import net.sf.briar.api.serial.Consumer; + +/** A consumer that passes its input through a signature. */ +class SigningConsumer implements Consumer { + + private final Signature signature; + + SigningConsumer(Signature signature) { + this.signature = signature; + } + + public void write(byte b) throws IOException { + try { + signature.update(b); + } catch(SignatureException e) { + throw new IOException(e.getMessage()); + } + } + + public void write(byte[] b, int off, int len) throws IOException { + try { + signature.update(b, off, len); + } catch(SignatureException e) { + throw new IOException(e.getMessage()); + } + } +} diff --git a/components/net/sf/briar/serial/ReaderImpl.java b/components/net/sf/briar/serial/ReaderImpl.java index 600e0abdf0..acbf2e4922 100644 --- a/components/net/sf/briar/serial/ReaderImpl.java +++ b/components/net/sf/briar/serial/ReaderImpl.java @@ -19,8 +19,8 @@ class ReaderImpl implements Reader { private static final byte[] EMPTY_BUFFER = new byte[] {}; private final InputStream in; + private final List<Consumer> consumers = new ArrayList<Consumer>(0); - private Consumer[] consumers = new Consumer[] {}; private ObjectReader<?>[] objectReaders = new ObjectReader<?>[] {}; private boolean hasLookahead = false, eof = false; private byte next, nextNext; @@ -89,24 +89,11 @@ class ReaderImpl implements Reader { } public void addConsumer(Consumer c) { - Consumer[] newConsumers = new Consumer[consumers.length + 1]; - System.arraycopy(consumers, 0, newConsumers, 0, consumers.length); - newConsumers[consumers.length] = c; - consumers = newConsumers; + consumers.add(c); } public void removeConsumer(Consumer c) { - if(consumers.length == 0) throw new IllegalArgumentException(); - Consumer[] newConsumers = new Consumer[consumers.length - 1]; - boolean found = false; - for(int src = 0, dest = 0; src < consumers.length; src++, dest++) { - if(!found && consumers[src].equals(c)) { - found = true; - src++; - } else newConsumers[dest] = consumers[src]; - } - if(found) consumers = newConsumers; - else throw new IllegalArgumentException(); + if(!consumers.remove(c)) throw new IllegalArgumentException(); } public void addObjectReader(int id, ObjectReader<?> o) { diff --git a/components/net/sf/briar/serial/WriterImpl.java b/components/net/sf/briar/serial/WriterImpl.java index ac786c57fd..d491a6a5ae 100644 --- a/components/net/sf/briar/serial/WriterImpl.java +++ b/components/net/sf/briar/serial/WriterImpl.java @@ -2,70 +2,81 @@ package net.sf.briar.serial; import java.io.IOException; import java.io.OutputStream; +import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Map.Entry; import net.sf.briar.api.Bytes; +import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Writable; import net.sf.briar.api.serial.Writer; class WriterImpl implements Writer { private final OutputStream out; + private final List<Consumer> consumers = new ArrayList<Consumer>(0); WriterImpl(OutputStream out) { this.out = out; } + public void addConsumer(Consumer c) { + consumers.add(c); + } + + public void removeConsumer(Consumer c) { + if(!consumers.remove(c)) throw new IllegalArgumentException(); + } + public void writeBoolean(boolean b) throws IOException { - if(b) out.write(Tag.TRUE); - else out.write(Tag.FALSE); + if(b) write(Tag.TRUE); + else write(Tag.FALSE); } public void writeUint7(byte b) throws IOException { if(b < 0) throw new IllegalArgumentException(); - out.write(b); + write(b); } public void writeInt8(byte b) throws IOException { - out.write(Tag.INT8); - out.write(b); + write(Tag.INT8); + write(b); } public void writeInt16(short s) throws IOException { - out.write(Tag.INT16); - out.write((byte) (s >> 8)); - out.write((byte) ((s << 8) >> 8)); + write(Tag.INT16); + write((byte) (s >> 8)); + write((byte) ((s << 8) >> 8)); } public void writeInt32(int i) throws IOException { - out.write(Tag.INT32); + write(Tag.INT32); writeInt32Bits(i); } private void writeInt32Bits(int i) throws IOException { - out.write((byte) (i >> 24)); - out.write((byte) ((i << 8) >> 24)); - out.write((byte) ((i << 16) >> 24)); - out.write((byte) ((i << 24) >> 24)); + write((byte) (i >> 24)); + write((byte) ((i << 8) >> 24)); + write((byte) ((i << 16) >> 24)); + write((byte) ((i << 24) >> 24)); } public void writeInt64(long l) throws IOException { - out.write(Tag.INT64); + write(Tag.INT64); writeInt64Bits(l); } private void writeInt64Bits(long l) throws IOException { - out.write((byte) (l >> 56)); - out.write((byte) ((l << 8) >> 56)); - out.write((byte) ((l << 16) >> 56)); - out.write((byte) ((l << 24) >> 56)); - out.write((byte) ((l << 32) >> 56)); - out.write((byte) ((l << 40) >> 56)); - out.write((byte) ((l << 48) >> 56)); - out.write((byte) ((l << 56) >> 56)); + write((byte) (l >> 56)); + write((byte) ((l << 8) >> 56)); + write((byte) ((l << 16) >> 56)); + write((byte) ((l << 24) >> 56)); + write((byte) ((l << 32) >> 56)); + write((byte) ((l << 40) >> 56)); + write((byte) ((l << 48) >> 56)); + write((byte) ((l << 56) >> 56)); } public void writeIntAny(long l) throws IOException { @@ -81,23 +92,23 @@ class WriterImpl implements Writer { } public void writeFloat32(float f) throws IOException { - out.write(Tag.FLOAT32); + write(Tag.FLOAT32); writeInt32Bits(Float.floatToRawIntBits(f)); } public void writeFloat64(double d) throws IOException { - out.write(Tag.FLOAT64); + write(Tag.FLOAT64); writeInt64Bits(Double.doubleToRawLongBits(d)); } public void writeString(String s) throws IOException { byte[] b = s.getBytes("UTF-8"); - if(b.length < 16) out.write((byte) (Tag.SHORT_STRING | b.length)); + if(b.length < 16) write((byte) (Tag.SHORT_STRING | b.length)); else { - out.write(Tag.STRING); + write(Tag.STRING); writeLength(b.length); } - out.write(b); + write(b); } private void writeLength(int i) throws IOException { @@ -109,19 +120,19 @@ class WriterImpl implements Writer { } public void writeBytes(byte[] b) throws IOException { - if(b.length < 16) out.write((byte) (Tag.SHORT_BYTES | b.length)); + if(b.length < 16) write((byte) (Tag.SHORT_BYTES | b.length)); else { - out.write(Tag.BYTES); + write(Tag.BYTES); writeLength(b.length); } - out.write(b); + write(b); } public void writeList(Collection<?> c) throws IOException { int length = c.size(); - if(length < 16) out.write((byte) (Tag.SHORT_LIST | length)); + if(length < 16) write((byte) (Tag.SHORT_LIST | length)); else { - out.write(Tag.LIST); + write(Tag.LIST); writeLength(length); } for(Object o : c) writeObject(o); @@ -145,18 +156,18 @@ class WriterImpl implements Writer { } public void writeListStart() throws IOException { - out.write(Tag.LIST_START); + write(Tag.LIST_START); } public void writeListEnd() throws IOException { - out.write(Tag.END); + write(Tag.END); } public void writeMap(Map<?, ?> m) throws IOException { int length = m.size(); - if(length < 16) out.write((byte) (Tag.SHORT_MAP | length)); + if(length < 16) write((byte) (Tag.SHORT_MAP | length)); else { - out.write(Tag.MAP); + write(Tag.MAP); writeLength(length); } for(Entry<?, ?> e : m.entrySet()) { @@ -166,24 +177,34 @@ class WriterImpl implements Writer { } public void writeMapStart() throws IOException { - out.write(Tag.MAP_START); + write(Tag.MAP_START); } public void writeMapEnd() throws IOException { - out.write(Tag.END); + write(Tag.END); } public void writeNull() throws IOException { - out.write(Tag.NULL); + write(Tag.NULL); } public void writeUserDefinedId(int id) throws IOException { if(id < 0 || id > 255) throw new IllegalArgumentException(); if(id < 32) { - out.write((byte) (Tag.SHORT_USER | id)); + write((byte) (Tag.SHORT_USER | id)); } else { - out.write(Tag.USER); - out.write((byte) id); + write(Tag.USER); + write((byte) id); } } + + private void write(byte b) throws IOException { + out.write(b); + for(Consumer c : consumers) c.write(b); + } + + private void write(byte[] b) throws IOException { + out.write(b); + for(Consumer c : consumers) c.write(b, 0, b.length); + } } diff --git a/components/net/sf/briar/transport/MacConsumer.java b/components/net/sf/briar/transport/MacConsumer.java index dbf20af8d1..f494cf2a55 100644 --- a/components/net/sf/briar/transport/MacConsumer.java +++ b/components/net/sf/briar/transport/MacConsumer.java @@ -17,10 +17,6 @@ class MacConsumer implements Consumer { mac.update(b); } - public void write(byte[] b) { - mac.update(b); - } - public void write(byte[] b, int off, int len) { mac.update(b, off, len); } diff --git a/test/net/sf/briar/FileReadWriteTest.java b/test/net/sf/briar/FileReadWriteTest.java index a624fb2bcb..12c5ebb935 100644 --- a/test/net/sf/briar/FileReadWriteTest.java +++ b/test/net/sf/briar/FileReadWriteTest.java @@ -130,7 +130,12 @@ public class FileReadWriteTest extends TestCase { } @Test - public void testWriteFile() throws Exception { + public void testWriteAndRead() throws Exception { + write(); + read(); + } + + private void write() throws Exception { OutputStream out = new FileOutputStream(file); // Use Alice's secret for writing ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out, @@ -177,10 +182,7 @@ public class FileReadWriteTest extends TestCase { assertTrue(file.length() > message.getSize()); } - @Test - public void testWriteAndReadFile() throws Exception { - - testWriteFile(); + private void read() throws Exception { InputStream in = new FileInputStream(file); byte[] iv = new byte[16]; diff --git a/test/net/sf/briar/serial/ReaderImplTest.java b/test/net/sf/briar/serial/ReaderImplTest.java index 982f0f397a..539d0da49c 100644 --- a/test/net/sf/briar/serial/ReaderImplTest.java +++ b/test/net/sf/briar/serial/ReaderImplTest.java @@ -416,10 +416,6 @@ public class ReaderImplTest extends TestCase { out.write(b); } - public void write(byte[] b) throws IOException { - out.write(b); - } - public void write(byte[] b, int off, int len) throws IOException { out.write(b, off, len); } -- GitLab