diff --git a/api/net/sf/briar/api/serial/Consumer.java b/api/net/sf/briar/api/serial/Consumer.java index 4091648d0f087b6f4b96745c674d7586c1c666e8..4d08ceb3825d0ed6a34df7e2c9bbf0138c54aaf2 100644 --- a/api/net/sf/briar/api/serial/Consumer.java +++ b/api/net/sf/briar/api/serial/Consumer.java @@ -6,7 +6,5 @@ public interface Consumer { void write(byte b) throws IOException; - void write(byte[] b) throws IOException; - void write(byte[] b, int off, int len) throws IOException; } diff --git a/api/net/sf/briar/api/serial/Writer.java b/api/net/sf/briar/api/serial/Writer.java index 887c1e51b32383fcb2486b13557ecb93e130331a..9b2b1482d31bcb477f4f2535180704272f01222a 100644 --- a/api/net/sf/briar/api/serial/Writer.java +++ b/api/net/sf/briar/api/serial/Writer.java @@ -6,6 +6,9 @@ import java.util.Map; public interface Writer { + void addConsumer(Consumer c); + void removeConsumer(Consumer c); + void writeBoolean(boolean b) throws IOException; void writeUint7(byte b) throws IOException; diff --git a/components/net/sf/briar/protocol/CopyingConsumer.java b/components/net/sf/briar/protocol/CopyingConsumer.java index 6e84d5929a90c64d1ac9d0132ea6faf6b286cf60..97129ab32713f6f08521073d0e103cff0c78e02b 100644 --- a/components/net/sf/briar/protocol/CopyingConsumer.java +++ b/components/net/sf/briar/protocol/CopyingConsumer.java @@ -18,10 +18,6 @@ class CopyingConsumer implements Consumer { out.write(b); } - public void write(byte[] b) throws IOException { - out.write(b); - } - public void write(byte[] b, int off, int len) throws IOException { out.write(b, off, len); } diff --git a/components/net/sf/briar/protocol/CountingConsumer.java b/components/net/sf/briar/protocol/CountingConsumer.java index bb032341e06040e2bbb7996f5d060671cf0a41d5..7722ce5437cbaf3540e0f88d17722d98630a854c 100644 --- a/components/net/sf/briar/protocol/CountingConsumer.java +++ b/components/net/sf/briar/protocol/CountingConsumer.java @@ -27,11 +27,6 @@ class CountingConsumer implements Consumer { if(count > limit) throw new FormatException(); } - public void write(byte[] b) throws IOException { - count += b.length; - if(count > limit) throw new FormatException(); - } - public void write(byte[] b, int off, int len) throws IOException { count += len; if(count > limit) throw new FormatException(); diff --git a/components/net/sf/briar/protocol/DigestingConsumer.java b/components/net/sf/briar/protocol/DigestingConsumer.java index 78826ff45be8aee78b0ccfed14c07a129b541294..29a0c260ef853632f4cae785396e1bf46a9ad660 100644 --- a/components/net/sf/briar/protocol/DigestingConsumer.java +++ b/components/net/sf/briar/protocol/DigestingConsumer.java @@ -17,10 +17,6 @@ class DigestingConsumer implements Consumer { messageDigest.update(b); } - public void write(byte[] b) { - messageDigest.update(b); - } - public void write(byte[] b, int off, int len) { messageDigest.update(b, off, len); } diff --git a/components/net/sf/briar/protocol/MessageEncoderImpl.java b/components/net/sf/briar/protocol/MessageEncoderImpl.java index 833bc3e418c75f3eea20aaded82305e3364d1595..acf1aeebf0ff539e606fb6dc369d53484bd378de 100644 --- a/components/net/sf/briar/protocol/MessageEncoderImpl.java +++ b/components/net/sf/briar/protocol/MessageEncoderImpl.java @@ -17,6 +17,7 @@ import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.MessageEncoder; import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.Types; +import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Writer; import net.sf.briar.api.serial.WriterFactory; @@ -24,14 +25,15 @@ import com.google.inject.Inject; class MessageEncoderImpl implements MessageEncoder { - private final Signature signature; + private final Signature authorSignature, groupSignature; private final SecureRandom random; private final MessageDigest messageDigest; private final WriterFactory writerFactory; @Inject MessageEncoderImpl(CryptoComponent crypto, WriterFactory writerFactory) { - signature = crypto.getSignature(); + authorSignature = crypto.getSignature(); + groupSignature = crypto.getSignature(); random = crypto.getSecureRandom(); messageDigest = crypto.getMessageDigest(); this.writerFactory = writerFactory; @@ -71,9 +73,23 @@ class MessageEncoderImpl implements MessageEncoder { if(body.length > Message.MAX_BODY_LENGTH) throw new IllegalArgumentException(); - long timestamp = System.currentTimeMillis(); ByteArrayOutputStream out = new ByteArrayOutputStream(); Writer w = writerFactory.createWriter(out); + // Initialise the consumers + Consumer digestingConsumer = new DigestingConsumer(messageDigest); + w.addConsumer(digestingConsumer); + Consumer authorConsumer = null; + if(authorKey != null) { + authorSignature.initSign(authorKey); + authorConsumer = new SigningConsumer(authorSignature); + w.addConsumer(authorConsumer); + } + Consumer groupConsumer = null; + if(groupKey != null) { + groupSignature.initSign(groupKey); + groupConsumer = new SigningConsumer(groupSignature); + w.addConsumer(groupConsumer); + } // Write the message w.writeUserDefinedId(Types.MESSAGE); if(parent == null) w.writeNull(); @@ -82,6 +98,7 @@ class MessageEncoderImpl implements MessageEncoder { else group.writeTo(w); if(author == null) w.writeNull(); else author.writeTo(w); + long timestamp = System.currentTimeMillis(); w.writeInt64(timestamp); byte[] salt = new byte[Message.SALT_LENGTH]; random.nextBytes(salt); @@ -91,9 +108,8 @@ class MessageEncoderImpl implements MessageEncoder { if(authorKey == null) { w.writeNull(); } else { - signature.initSign(authorKey); - signature.update(out.toByteArray()); - byte[] sig = signature.sign(); + w.removeConsumer(authorConsumer); + byte[] sig = authorSignature.sign(); if(sig.length > Message.MAX_SIGNATURE_LENGTH) throw new IllegalArgumentException(); w.writeBytes(sig); @@ -102,17 +118,15 @@ class MessageEncoderImpl implements MessageEncoder { if(groupKey == null) { w.writeNull(); } else { - signature.initSign(groupKey); - signature.update(out.toByteArray()); - byte[] sig = signature.sign(); + w.removeConsumer(groupConsumer); + byte[] sig = groupSignature.sign(); if(sig.length > Message.MAX_SIGNATURE_LENGTH) throw new IllegalArgumentException(); w.writeBytes(sig); } // Hash the message, including the signatures, to get the message ID + w.removeConsumer(digestingConsumer); byte[] raw = out.toByteArray(); - messageDigest.reset(); - messageDigest.update(raw); MessageId id = new MessageId(messageDigest.digest()); GroupId groupId = group == null ? null : group.getId(); AuthorId authorId = author == null ? null : author.getId(); diff --git a/components/net/sf/briar/protocol/SigningConsumer.java b/components/net/sf/briar/protocol/SigningConsumer.java new file mode 100644 index 0000000000000000000000000000000000000000..c3274ba7ff7ca5bdb0f74c0bcb5c1d6b5ec82c13 --- /dev/null +++ b/components/net/sf/briar/protocol/SigningConsumer.java @@ -0,0 +1,33 @@ +package net.sf.briar.protocol; + +import java.io.IOException; +import java.security.Signature; +import java.security.SignatureException; + +import net.sf.briar.api.serial.Consumer; + +/** A consumer that passes its input through a signature. */ +class SigningConsumer implements Consumer { + + private final Signature signature; + + SigningConsumer(Signature signature) { + this.signature = signature; + } + + public void write(byte b) throws IOException { + try { + signature.update(b); + } catch(SignatureException e) { + throw new IOException(e.getMessage()); + } + } + + public void write(byte[] b, int off, int len) throws IOException { + try { + signature.update(b, off, len); + } catch(SignatureException e) { + throw new IOException(e.getMessage()); + } + } +} diff --git a/components/net/sf/briar/serial/ReaderImpl.java b/components/net/sf/briar/serial/ReaderImpl.java index 600e0abdf05397588eb64ea1f8a5960b952794fe..acbf2e4922fb23f920134fb2fa9ff2db8a6831f1 100644 --- a/components/net/sf/briar/serial/ReaderImpl.java +++ b/components/net/sf/briar/serial/ReaderImpl.java @@ -19,8 +19,8 @@ class ReaderImpl implements Reader { private static final byte[] EMPTY_BUFFER = new byte[] {}; private final InputStream in; + private final List<Consumer> consumers = new ArrayList<Consumer>(0); - private Consumer[] consumers = new Consumer[] {}; private ObjectReader<?>[] objectReaders = new ObjectReader<?>[] {}; private boolean hasLookahead = false, eof = false; private byte next, nextNext; @@ -89,24 +89,11 @@ class ReaderImpl implements Reader { } public void addConsumer(Consumer c) { - Consumer[] newConsumers = new Consumer[consumers.length + 1]; - System.arraycopy(consumers, 0, newConsumers, 0, consumers.length); - newConsumers[consumers.length] = c; - consumers = newConsumers; + consumers.add(c); } public void removeConsumer(Consumer c) { - if(consumers.length == 0) throw new IllegalArgumentException(); - Consumer[] newConsumers = new Consumer[consumers.length - 1]; - boolean found = false; - for(int src = 0, dest = 0; src < consumers.length; src++, dest++) { - if(!found && consumers[src].equals(c)) { - found = true; - src++; - } else newConsumers[dest] = consumers[src]; - } - if(found) consumers = newConsumers; - else throw new IllegalArgumentException(); + if(!consumers.remove(c)) throw new IllegalArgumentException(); } public void addObjectReader(int id, ObjectReader<?> o) { diff --git a/components/net/sf/briar/serial/WriterImpl.java b/components/net/sf/briar/serial/WriterImpl.java index ac786c57fdc55654af68964bd9b8df16c357f7ee..d491a6a5aecccc87684c6d79784137dda8823366 100644 --- a/components/net/sf/briar/serial/WriterImpl.java +++ b/components/net/sf/briar/serial/WriterImpl.java @@ -2,70 +2,81 @@ package net.sf.briar.serial; import java.io.IOException; import java.io.OutputStream; +import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Map.Entry; import net.sf.briar.api.Bytes; +import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.Writable; import net.sf.briar.api.serial.Writer; class WriterImpl implements Writer { private final OutputStream out; + private final List<Consumer> consumers = new ArrayList<Consumer>(0); WriterImpl(OutputStream out) { this.out = out; } + public void addConsumer(Consumer c) { + consumers.add(c); + } + + public void removeConsumer(Consumer c) { + if(!consumers.remove(c)) throw new IllegalArgumentException(); + } + public void writeBoolean(boolean b) throws IOException { - if(b) out.write(Tag.TRUE); - else out.write(Tag.FALSE); + if(b) write(Tag.TRUE); + else write(Tag.FALSE); } public void writeUint7(byte b) throws IOException { if(b < 0) throw new IllegalArgumentException(); - out.write(b); + write(b); } public void writeInt8(byte b) throws IOException { - out.write(Tag.INT8); - out.write(b); + write(Tag.INT8); + write(b); } public void writeInt16(short s) throws IOException { - out.write(Tag.INT16); - out.write((byte) (s >> 8)); - out.write((byte) ((s << 8) >> 8)); + write(Tag.INT16); + write((byte) (s >> 8)); + write((byte) ((s << 8) >> 8)); } public void writeInt32(int i) throws IOException { - out.write(Tag.INT32); + write(Tag.INT32); writeInt32Bits(i); } private void writeInt32Bits(int i) throws IOException { - out.write((byte) (i >> 24)); - out.write((byte) ((i << 8) >> 24)); - out.write((byte) ((i << 16) >> 24)); - out.write((byte) ((i << 24) >> 24)); + write((byte) (i >> 24)); + write((byte) ((i << 8) >> 24)); + write((byte) ((i << 16) >> 24)); + write((byte) ((i << 24) >> 24)); } public void writeInt64(long l) throws IOException { - out.write(Tag.INT64); + write(Tag.INT64); writeInt64Bits(l); } private void writeInt64Bits(long l) throws IOException { - out.write((byte) (l >> 56)); - out.write((byte) ((l << 8) >> 56)); - out.write((byte) ((l << 16) >> 56)); - out.write((byte) ((l << 24) >> 56)); - out.write((byte) ((l << 32) >> 56)); - out.write((byte) ((l << 40) >> 56)); - out.write((byte) ((l << 48) >> 56)); - out.write((byte) ((l << 56) >> 56)); + write((byte) (l >> 56)); + write((byte) ((l << 8) >> 56)); + write((byte) ((l << 16) >> 56)); + write((byte) ((l << 24) >> 56)); + write((byte) ((l << 32) >> 56)); + write((byte) ((l << 40) >> 56)); + write((byte) ((l << 48) >> 56)); + write((byte) ((l << 56) >> 56)); } public void writeIntAny(long l) throws IOException { @@ -81,23 +92,23 @@ class WriterImpl implements Writer { } public void writeFloat32(float f) throws IOException { - out.write(Tag.FLOAT32); + write(Tag.FLOAT32); writeInt32Bits(Float.floatToRawIntBits(f)); } public void writeFloat64(double d) throws IOException { - out.write(Tag.FLOAT64); + write(Tag.FLOAT64); writeInt64Bits(Double.doubleToRawLongBits(d)); } public void writeString(String s) throws IOException { byte[] b = s.getBytes("UTF-8"); - if(b.length < 16) out.write((byte) (Tag.SHORT_STRING | b.length)); + if(b.length < 16) write((byte) (Tag.SHORT_STRING | b.length)); else { - out.write(Tag.STRING); + write(Tag.STRING); writeLength(b.length); } - out.write(b); + write(b); } private void writeLength(int i) throws IOException { @@ -109,19 +120,19 @@ class WriterImpl implements Writer { } public void writeBytes(byte[] b) throws IOException { - if(b.length < 16) out.write((byte) (Tag.SHORT_BYTES | b.length)); + if(b.length < 16) write((byte) (Tag.SHORT_BYTES | b.length)); else { - out.write(Tag.BYTES); + write(Tag.BYTES); writeLength(b.length); } - out.write(b); + write(b); } public void writeList(Collection<?> c) throws IOException { int length = c.size(); - if(length < 16) out.write((byte) (Tag.SHORT_LIST | length)); + if(length < 16) write((byte) (Tag.SHORT_LIST | length)); else { - out.write(Tag.LIST); + write(Tag.LIST); writeLength(length); } for(Object o : c) writeObject(o); @@ -145,18 +156,18 @@ class WriterImpl implements Writer { } public void writeListStart() throws IOException { - out.write(Tag.LIST_START); + write(Tag.LIST_START); } public void writeListEnd() throws IOException { - out.write(Tag.END); + write(Tag.END); } public void writeMap(Map<?, ?> m) throws IOException { int length = m.size(); - if(length < 16) out.write((byte) (Tag.SHORT_MAP | length)); + if(length < 16) write((byte) (Tag.SHORT_MAP | length)); else { - out.write(Tag.MAP); + write(Tag.MAP); writeLength(length); } for(Entry<?, ?> e : m.entrySet()) { @@ -166,24 +177,34 @@ class WriterImpl implements Writer { } public void writeMapStart() throws IOException { - out.write(Tag.MAP_START); + write(Tag.MAP_START); } public void writeMapEnd() throws IOException { - out.write(Tag.END); + write(Tag.END); } public void writeNull() throws IOException { - out.write(Tag.NULL); + write(Tag.NULL); } public void writeUserDefinedId(int id) throws IOException { if(id < 0 || id > 255) throw new IllegalArgumentException(); if(id < 32) { - out.write((byte) (Tag.SHORT_USER | id)); + write((byte) (Tag.SHORT_USER | id)); } else { - out.write(Tag.USER); - out.write((byte) id); + write(Tag.USER); + write((byte) id); } } + + private void write(byte b) throws IOException { + out.write(b); + for(Consumer c : consumers) c.write(b); + } + + private void write(byte[] b) throws IOException { + out.write(b); + for(Consumer c : consumers) c.write(b, 0, b.length); + } } diff --git a/components/net/sf/briar/transport/MacConsumer.java b/components/net/sf/briar/transport/MacConsumer.java index dbf20af8d12b458e1116511cc5001ee8e1a97322..f494cf2a554c1fce2d42ba75241b5cc3d6d58829 100644 --- a/components/net/sf/briar/transport/MacConsumer.java +++ b/components/net/sf/briar/transport/MacConsumer.java @@ -17,10 +17,6 @@ class MacConsumer implements Consumer { mac.update(b); } - public void write(byte[] b) { - mac.update(b); - } - public void write(byte[] b, int off, int len) { mac.update(b, off, len); } diff --git a/test/net/sf/briar/FileReadWriteTest.java b/test/net/sf/briar/FileReadWriteTest.java index a624fb2bcb973b78b14714d92bf1bf463ab2e810..12c5ebb9353d032e00600d3ad8456f3a45c863e2 100644 --- a/test/net/sf/briar/FileReadWriteTest.java +++ b/test/net/sf/briar/FileReadWriteTest.java @@ -130,7 +130,12 @@ public class FileReadWriteTest extends TestCase { } @Test - public void testWriteFile() throws Exception { + public void testWriteAndRead() throws Exception { + write(); + read(); + } + + private void write() throws Exception { OutputStream out = new FileOutputStream(file); // Use Alice's secret for writing ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out, @@ -177,10 +182,7 @@ public class FileReadWriteTest extends TestCase { assertTrue(file.length() > message.getSize()); } - @Test - public void testWriteAndReadFile() throws Exception { - - testWriteFile(); + private void read() throws Exception { InputStream in = new FileInputStream(file); byte[] iv = new byte[16]; diff --git a/test/net/sf/briar/serial/ReaderImplTest.java b/test/net/sf/briar/serial/ReaderImplTest.java index 982f0f397a036739b750f4197a139138c277f54f..539d0da49c322080a338bf5c59c2c5cd1b9f183d 100644 --- a/test/net/sf/briar/serial/ReaderImplTest.java +++ b/test/net/sf/briar/serial/ReaderImplTest.java @@ -416,10 +416,6 @@ public class ReaderImplTest extends TestCase { out.write(b); } - public void write(byte[] b) throws IOException { - out.write(b); - } - public void write(byte[] b, int off, int len) throws IOException { out.write(b, off, len); }