From fb528a85adaa7a38738730e52d5b6066cfbe20df Mon Sep 17 00:00:00 2001 From: akwizgran <akwizgran@users.sourceforge.net> Date: Tue, 19 Jul 2011 17:17:45 +0100 Subject: [PATCH] Nested user-defined objects (and collections of them) can now be read by registering ObjectReaders with the Reader. --- .../sf/briar/api/protocol/BundleWriter.java | 5 +- api/net/sf/briar/api/protocol/Tags.java | 22 +++-- api/net/sf/briar/api/serial/ObjectReader.java | 8 ++ api/net/sf/briar/api/serial/Reader.java | 4 + api/net/sf/briar/api/serial/Writer.java | 4 +- .../net/sf/briar/protocol/BatchIdReader.java | 18 ++++ .../sf/briar/protocol/BundleReaderImpl.java | 52 ++++------- .../sf/briar/protocol/BundleWriterImpl.java | 25 ++--- .../net/sf/briar/protocol/GroupIdReader.java | 18 ++++ .../net/sf/briar/protocol/HeaderFactory.java | 4 +- .../sf/briar/protocol/HeaderFactoryImpl.java | 11 ++- .../sf/briar/protocol/MessageEncoderImpl.java | 20 ++-- .../net/sf/briar/protocol/MessageReader.java | 88 +++++++++++++++++- .../sf/briar/protocol/MessageReaderImpl.java | 92 ------------------- .../net/sf/briar/protocol/ProtocolModule.java | 3 - .../net/sf/briar/serial/ReaderImpl.java | 29 ++++++ .../net/sf/briar/serial/WriterImpl.java | 11 ++- test/build.xml | 1 + .../briar/protocol/BundleReadWriteTest.java | 8 +- .../SigningDigestingOutputStreamTest.java | 81 ++++++++++++++++ test/net/sf/briar/serial/ReaderImplTest.java | 70 ++++++++++++++ test/net/sf/briar/serial/WriterImplTest.java | 17 ++++ 22 files changed, 414 insertions(+), 177 deletions(-) create mode 100644 api/net/sf/briar/api/serial/ObjectReader.java create mode 100644 components/net/sf/briar/protocol/BatchIdReader.java create mode 100644 components/net/sf/briar/protocol/GroupIdReader.java delete mode 100644 components/net/sf/briar/protocol/MessageReaderImpl.java create mode 100644 test/net/sf/briar/protocol/SigningDigestingOutputStreamTest.java diff --git a/api/net/sf/briar/api/protocol/BundleWriter.java b/api/net/sf/briar/api/protocol/BundleWriter.java index fff3e66d92..029f1a340a 100644 --- a/api/net/sf/briar/api/protocol/BundleWriter.java +++ b/api/net/sf/briar/api/protocol/BundleWriter.java @@ -2,6 +2,7 @@ package net.sf.briar.api.protocol; import java.io.IOException; import java.security.GeneralSecurityException; +import java.util.Collection; import java.util.Map; import net.sf.briar.api.serial.Raw; @@ -16,12 +17,12 @@ public interface BundleWriter { long getRemainingCapacity() throws IOException; /** Adds a header to the bundle. */ - void addHeader(Iterable<BatchId> acks, Iterable<GroupId> subs, + void addHeader(Collection<BatchId> acks, Collection<GroupId> subs, Map<String, String> transports) throws IOException, GeneralSecurityException; /** Adds a batch of messages to the bundle and returns its identifier. */ - BatchId addBatch(Iterable<Raw> messages) throws IOException, + BatchId addBatch(Collection<Raw> messages) throws IOException, GeneralSecurityException; /** Finishes writing the bundle. */ diff --git a/api/net/sf/briar/api/protocol/Tags.java b/api/net/sf/briar/api/protocol/Tags.java index f2767b1ac9..c3e2b8193e 100644 --- a/api/net/sf/briar/api/protocol/Tags.java +++ b/api/net/sf/briar/api/protocol/Tags.java @@ -2,15 +2,17 @@ package net.sf.briar.api.protocol; public interface Tags { - static final int HEADER = 0; - static final int BATCH_ID = 1; - static final int GROUP_ID = 2; - static final int TIMESTAMP = 3; - static final int SIGNATURE = 4; - static final int BATCH = 5; + static final int AUTHOR_ID = 1; + static final int BATCH = 2; + static final int BATCH_ID = 3; + static final int GROUP_ID = 4; + static final int HEADER = 5; static final int MESSAGE = 6; - static final int MESSAGE_ID = 7; - static final int AUTHOR = 8; - static final int MESSAGE_BODY = 9; - static final int AUTHOR_ID = 10; + static final int MESSAGE_BODY = 7; + static final int MESSAGE_ID = 8; + static final int NICKNAME = 9; + static final int PUBLIC_KEY = 10; + static final int SIGNATURE = 12; + static final int TIMESTAMP = 13; + static final int TRANSPORTS = 14; } diff --git a/api/net/sf/briar/api/serial/ObjectReader.java b/api/net/sf/briar/api/serial/ObjectReader.java new file mode 100644 index 0000000000..012f5f800a --- /dev/null +++ b/api/net/sf/briar/api/serial/ObjectReader.java @@ -0,0 +1,8 @@ +package net.sf.briar.api.serial; + +import java.io.IOException; + +public interface ObjectReader<T> { + + T readObject(Reader r) throws IOException; +} diff --git a/api/net/sf/briar/api/serial/Reader.java b/api/net/sf/briar/api/serial/Reader.java index e9fd9d0d27..731b860ce0 100644 --- a/api/net/sf/briar/api/serial/Reader.java +++ b/api/net/sf/briar/api/serial/Reader.java @@ -12,6 +12,9 @@ public interface Reader { void addConsumer(Consumer c); void removeConsumer(Consumer c); + void addObjectReader(int tag, ObjectReader<?> o); + void removeObjectReader(int tag); + boolean hasBoolean() throws IOException; boolean readBoolean() throws IOException; @@ -60,4 +63,5 @@ public interface Reader { boolean hasUserDefinedTag() throws IOException; int readUserDefinedTag() throws IOException; void readUserDefinedTag(int tag) throws IOException; + <T> T readUserDefinedObject(int tag) throws IOException; } diff --git a/api/net/sf/briar/api/serial/Writer.java b/api/net/sf/briar/api/serial/Writer.java index e1c4a2721e..ba4f3d2556 100644 --- a/api/net/sf/briar/api/serial/Writer.java +++ b/api/net/sf/briar/api/serial/Writer.java @@ -1,7 +1,7 @@ package net.sf.briar.api.serial; import java.io.IOException; -import java.util.List; +import java.util.Collection; import java.util.Map; public interface Writer { @@ -25,7 +25,7 @@ public interface Writer { void writeRaw(byte[] b) throws IOException; void writeRaw(Raw r) throws IOException; - void writeList(List<?> l) throws IOException; + void writeList(Collection<?> c) throws IOException; void writeListStart() throws IOException; void writeListEnd() throws IOException; diff --git a/components/net/sf/briar/protocol/BatchIdReader.java b/components/net/sf/briar/protocol/BatchIdReader.java new file mode 100644 index 0000000000..5114141f19 --- /dev/null +++ b/components/net/sf/briar/protocol/BatchIdReader.java @@ -0,0 +1,18 @@ +package net.sf.briar.protocol; + +import java.io.IOException; + +import net.sf.briar.api.protocol.BatchId; +import net.sf.briar.api.protocol.UniqueId; +import net.sf.briar.api.serial.FormatException; +import net.sf.briar.api.serial.ObjectReader; +import net.sf.briar.api.serial.Reader; + +public class BatchIdReader implements ObjectReader<BatchId> { + + public BatchId readObject(Reader r) throws IOException { + byte[] b = r.readRaw(); + if(b.length != UniqueId.LENGTH) throw new FormatException(); + return new BatchId(b); + } +} diff --git a/components/net/sf/briar/protocol/BundleReaderImpl.java b/components/net/sf/briar/protocol/BundleReaderImpl.java index 2c80d9b1b2..c5ff59d091 100644 --- a/components/net/sf/briar/protocol/BundleReaderImpl.java +++ b/components/net/sf/briar/protocol/BundleReaderImpl.java @@ -6,11 +6,9 @@ import java.security.MessageDigest; import java.security.PublicKey; import java.security.Signature; import java.security.SignatureException; -import java.util.ArrayList; -import java.util.HashSet; +import java.util.Collection; import java.util.List; import java.util.Map; -import java.util.Set; import net.sf.briar.api.protocol.Batch; import net.sf.briar.api.protocol.BatchId; @@ -19,8 +17,8 @@ import net.sf.briar.api.protocol.GroupId; import net.sf.briar.api.protocol.Header; import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.Tags; -import net.sf.briar.api.protocol.UniqueId; import net.sf.briar.api.serial.FormatException; +import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.Reader; class BundleReaderImpl implements BundleReader { @@ -31,19 +29,19 @@ class BundleReaderImpl implements BundleReader { private final PublicKey publicKey; private final Signature signature; private final MessageDigest messageDigest; - private final MessageReader messageReader; + private final ObjectReader<Message> messageReader; private final HeaderFactory headerFactory; private final BatchFactory batchFactory; private State state = State.START; BundleReaderImpl(Reader reader, PublicKey publicKey, Signature signature, - MessageDigest messageDigest, MessageReader messageParser, + MessageDigest messageDigest, ObjectReader<Message> messageReader, HeaderFactory headerFactory, BatchFactory batchFactory) { this.reader = reader; this.publicKey = publicKey; this.signature = signature; this.messageDigest = messageDigest; - this.messageReader = messageParser; + this.messageReader = messageReader; this.headerFactory = headerFactory; this.batchFactory = batchFactory; } @@ -55,37 +53,27 @@ class BundleReaderImpl implements BundleReader { CountingConsumer counting = new CountingConsumer(Header.MAX_SIZE); SigningConsumer signing = new SigningConsumer(signature); signature.initVerify(publicKey); + // Read the initial tag + reader.readUserDefinedTag(Tags.HEADER); // Read the signed data reader.addConsumer(counting); reader.addConsumer(signing); - reader.readUserDefinedTag(Tags.HEADER); + reader.addObjectReader(Tags.BATCH_ID, new BatchIdReader()); + reader.addObjectReader(Tags.GROUP_ID, new GroupIdReader()); // Acks - Set<BatchId> acks = new HashSet<BatchId>(); - reader.readListStart(); - while(!reader.hasListEnd()) { - reader.readUserDefinedTag(Tags.BATCH_ID); - byte[] b = reader.readRaw(); - if(b.length != UniqueId.LENGTH) throw new FormatException(); - acks.add(new BatchId(b)); - } - reader.readListEnd(); + Collection<BatchId> acks = reader.readList(BatchId.class); // Subs - Set<GroupId> subs = new HashSet<GroupId>(); - reader.readListStart(); - while(!reader.hasListEnd()) { - reader.readUserDefinedTag(Tags.GROUP_ID); - byte[] b = reader.readRaw(); - if(b.length != UniqueId.LENGTH) throw new FormatException(); - subs.add(new GroupId(b)); - } - reader.readListEnd(); + Collection<GroupId> subs = reader.readList(GroupId.class); // Transports + reader.readUserDefinedTag(Tags.TRANSPORTS); Map<String, String> transports = reader.readMap(String.class, String.class); // Timestamp reader.readUserDefinedTag(Tags.TIMESTAMP); long timestamp = reader.readInt64(); if(timestamp < 0L) throw new FormatException(); + reader.removeObjectReader(Tags.GROUP_ID); + reader.removeObjectReader(Tags.BATCH_ID); reader.removeConsumer(signing); // Read and verify the signature reader.readUserDefinedTag(Tags.SIGNATURE); @@ -115,17 +103,15 @@ class BundleReaderImpl implements BundleReader { messageDigest.reset(); SigningConsumer signing = new SigningConsumer(signature); signature.initVerify(publicKey); + // Read the initial tag + reader.readUserDefinedTag(Tags.BATCH); // Read the signed data reader.addConsumer(counting); reader.addConsumer(digesting); reader.addConsumer(signing); - reader.readUserDefinedTag(Tags.BATCH); - List<Message> messages = new ArrayList<Message>(); - reader.readListStart(); - while(!reader.hasListEnd()) { - messages.add(messageReader.readMessage(reader)); - } - reader.readListEnd(); + reader.addObjectReader(Tags.MESSAGE, messageReader); + List<Message> messages = reader.readList(Message.class); + reader.removeObjectReader(Tags.MESSAGE); reader.removeConsumer(signing); // Read and verify the signature reader.readUserDefinedTag(Tags.SIGNATURE); diff --git a/components/net/sf/briar/protocol/BundleWriterImpl.java b/components/net/sf/briar/protocol/BundleWriterImpl.java index 6f4802e7a1..dc5aab7c36 100644 --- a/components/net/sf/briar/protocol/BundleWriterImpl.java +++ b/components/net/sf/briar/protocol/BundleWriterImpl.java @@ -6,6 +6,7 @@ import java.security.GeneralSecurityException; import java.security.MessageDigest; import java.security.PrivateKey; import java.security.Signature; +import java.util.Collection; import java.util.Map; import net.sf.briar.api.protocol.BatchId; @@ -44,24 +45,22 @@ class BundleWriterImpl implements BundleWriter { return capacity - writer.getBytesWritten(); } - public void addHeader(Iterable<BatchId> acks, Iterable<GroupId> subs, + public void addHeader(Collection<BatchId> acks, Collection<GroupId> subs, Map<String, String> transports) throws IOException, GeneralSecurityException { if(state != State.START) throw new IllegalStateException(); // Initialise the output stream signature.initSign(privateKey); + // Write the initial tag + writer.writeUserDefinedTag(Tags.HEADER); // Write the data to be signed out.setSigning(true); - writer.writeUserDefinedTag(Tags.HEADER); // Acks - writer.writeListStart(); - for(BatchId ack : acks) ack.writeTo(writer); - writer.writeListEnd(); + writer.writeList(acks); // Subs - writer.writeListStart(); - for(GroupId sub : subs) sub.writeTo(writer); - writer.writeListEnd(); + writer.writeList(subs); // Transports + writer.writeUserDefinedTag(Tags.TRANSPORTS); writer.writeMap(transports); // Timestamp writer.writeUserDefinedTag(Tags.TIMESTAMP); @@ -75,7 +74,7 @@ class BundleWriterImpl implements BundleWriter { state = State.FIRST_BATCH; } - public BatchId addBatch(Iterable<Raw> messages) throws IOException, + public BatchId addBatch(Collection<Raw> messages) throws IOException, GeneralSecurityException { if(state == State.FIRST_BATCH) { writer.writeListStart(); @@ -85,13 +84,17 @@ class BundleWriterImpl implements BundleWriter { // Initialise the output stream signature.initSign(privateKey); messageDigest.reset(); + // Write the initial tag + writer.writeUserDefinedTag(Tags.BATCH); // Write the data to be signed out.setDigesting(true); out.setSigning(true); - writer.writeUserDefinedTag(Tags.BATCH); writer.writeListStart(); // Bypass the writer and write the raw messages directly - for(Raw message : messages) out.write(message.getBytes()); + for(Raw message : messages) { + writer.writeUserDefinedTag(Tags.MESSAGE); + out.write(message.getBytes()); + } writer.writeListEnd(); out.setSigning(false); // Create and write the signature diff --git a/components/net/sf/briar/protocol/GroupIdReader.java b/components/net/sf/briar/protocol/GroupIdReader.java new file mode 100644 index 0000000000..bc1b1ab1be --- /dev/null +++ b/components/net/sf/briar/protocol/GroupIdReader.java @@ -0,0 +1,18 @@ +package net.sf.briar.protocol; + +import java.io.IOException; + +import net.sf.briar.api.protocol.GroupId; +import net.sf.briar.api.protocol.UniqueId; +import net.sf.briar.api.serial.FormatException; +import net.sf.briar.api.serial.ObjectReader; +import net.sf.briar.api.serial.Reader; + +public class GroupIdReader implements ObjectReader<GroupId> { + + public GroupId readObject(Reader r) throws IOException { + byte[] b = r.readRaw(); + if(b.length != UniqueId.LENGTH) throw new FormatException(); + return new GroupId(b); + } +} diff --git a/components/net/sf/briar/protocol/HeaderFactory.java b/components/net/sf/briar/protocol/HeaderFactory.java index a48a4f6747..52099b3541 100644 --- a/components/net/sf/briar/protocol/HeaderFactory.java +++ b/components/net/sf/briar/protocol/HeaderFactory.java @@ -1,7 +1,7 @@ package net.sf.briar.protocol; +import java.util.Collection; import java.util.Map; -import java.util.Set; import net.sf.briar.api.protocol.BatchId; import net.sf.briar.api.protocol.GroupId; @@ -9,6 +9,6 @@ import net.sf.briar.api.protocol.Header; interface HeaderFactory { - Header createHeader(Set<BatchId> acks, Set<GroupId> subs, + Header createHeader(Collection<BatchId> acks, Collection<GroupId> subs, Map<String, String> transports, long timestamp); } diff --git a/components/net/sf/briar/protocol/HeaderFactoryImpl.java b/components/net/sf/briar/protocol/HeaderFactoryImpl.java index 93bd6ba2c3..07a34f7b30 100644 --- a/components/net/sf/briar/protocol/HeaderFactoryImpl.java +++ b/components/net/sf/briar/protocol/HeaderFactoryImpl.java @@ -1,5 +1,7 @@ package net.sf.briar.protocol; +import java.util.Collection; +import java.util.HashSet; import java.util.Map; import java.util.Set; @@ -9,8 +11,11 @@ import net.sf.briar.api.protocol.Header; class HeaderFactoryImpl implements HeaderFactory { - public Header createHeader(Set<BatchId> acks, Set<GroupId> subs, - Map<String, String> transports, long timestamp) { - return new HeaderImpl(acks, subs, transports, timestamp); + public Header createHeader(Collection<BatchId> acks, + Collection<GroupId> subs, Map<String, String> transports, + long timestamp) { + Set<BatchId> ackSet = new HashSet<BatchId>(acks); + Set<GroupId> subSet = new HashSet<GroupId>(subs); + return new HeaderImpl(ackSet, subSet, transports, timestamp); } } diff --git a/components/net/sf/briar/protocol/MessageEncoderImpl.java b/components/net/sf/briar/protocol/MessageEncoderImpl.java index f59ba57eee..c3b1748764 100644 --- a/components/net/sf/briar/protocol/MessageEncoderImpl.java +++ b/components/net/sf/briar/protocol/MessageEncoderImpl.java @@ -33,23 +33,26 @@ class MessageEncoderImpl implements MessageEncoder { KeyPair keyPair, byte[] body) throws IOException, GeneralSecurityException { long timestamp = System.currentTimeMillis(); - byte[] encodedKey = keyPair.getPublic().getEncoded(); ByteArrayOutputStream out = new ByteArrayOutputStream(); Writer w = writerFactory.createWriter(out); - w.writeUserDefinedTag(Tags.MESSAGE); + // Write the message parent.writeTo(w); group.writeTo(w); w.writeUserDefinedTag(Tags.TIMESTAMP); w.writeInt64(timestamp); - w.writeUserDefinedTag(Tags.AUTHOR); + w.writeUserDefinedTag(Tags.NICKNAME); w.writeString(nick); - w.writeRaw(encodedKey); + w.writeUserDefinedTag(Tags.PUBLIC_KEY); + w.writeRaw(keyPair.getPublic().getEncoded()); w.writeUserDefinedTag(Tags.MESSAGE_BODY); w.writeRaw(body); + // Sign the message byte[] signable = out.toByteArray(); signature.initSign(keyPair.getPrivate()); signature.update(signable); byte[] sig = signature.sign(); + signable = null; + // Write the signature w.writeUserDefinedTag(Tags.SIGNATURE); w.writeRaw(sig); byte[] raw = out.toByteArray(); @@ -61,13 +64,14 @@ class MessageEncoderImpl implements MessageEncoder { // The author ID is the hash of the author's nick and public key out.reset(); w = writerFactory.createWriter(out); - w.writeUserDefinedTag(Tags.AUTHOR); + w.writeUserDefinedTag(Tags.NICKNAME); w.writeString(nick); - w.writeRaw(encodedKey); + w.writeUserDefinedTag(Tags.PUBLIC_KEY); + w.writeRaw(keyPair.getPublic().getEncoded()); w.close(); messageDigest.reset(); messageDigest.update(out.toByteArray()); - AuthorId author = new AuthorId(messageDigest.digest()); - return new MessageImpl(id, parent, group, author, timestamp, raw); + AuthorId authorId = new AuthorId(messageDigest.digest()); + return new MessageImpl(id, parent, group, authorId, timestamp, raw); } } diff --git a/components/net/sf/briar/protocol/MessageReader.java b/components/net/sf/briar/protocol/MessageReader.java index 904df2f90e..74022dfed5 100644 --- a/components/net/sf/briar/protocol/MessageReader.java +++ b/components/net/sf/briar/protocol/MessageReader.java @@ -2,12 +2,94 @@ package net.sf.briar.protocol; import java.io.IOException; import java.security.GeneralSecurityException; +import java.security.MessageDigest; +import java.security.PublicKey; +import java.security.Signature; +import java.security.SignatureException; +import java.security.spec.InvalidKeySpecException; +import net.sf.briar.api.crypto.KeyParser; +import net.sf.briar.api.protocol.AuthorId; +import net.sf.briar.api.protocol.GroupId; import net.sf.briar.api.protocol.Message; +import net.sf.briar.api.protocol.MessageId; +import net.sf.briar.api.protocol.Tags; +import net.sf.briar.api.protocol.UniqueId; +import net.sf.briar.api.serial.FormatException; +import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.Reader; -interface MessageReader { +class MessageReader implements ObjectReader<Message> { - Message readMessage(Reader r) throws IOException, - GeneralSecurityException; + private final KeyParser keyParser; + private final Signature signature; + private final MessageDigest messageDigest; + + MessageReader(KeyParser keyParser, Signature signature, + MessageDigest messageDigest) { + this.keyParser = keyParser; + this.signature = signature; + this.messageDigest = messageDigest; + } + + public Message readObject(Reader reader) throws IOException { + CopyingConsumer copying = new CopyingConsumer(); + CountingConsumer counting = new CountingConsumer(Message.MAX_SIZE); + reader.addConsumer(copying); + reader.addConsumer(counting); + // Read the parent's message ID + reader.readUserDefinedTag(Tags.MESSAGE_ID); + byte[] b = reader.readRaw(); + if(b.length != UniqueId.LENGTH) throw new FormatException(); + MessageId parent = new MessageId(b); + // Read the group ID + reader.readUserDefinedTag(Tags.GROUP_ID); + b = reader.readRaw(); + if(b.length != UniqueId.LENGTH) throw new FormatException(); + GroupId group = new GroupId(b); + // Read the timestamp + reader.readUserDefinedTag(Tags.TIMESTAMP); + long timestamp = reader.readInt64(); + if(timestamp < 0L) throw new FormatException(); + // Hash the author's nick and public key to get the author ID + DigestingConsumer digesting = new DigestingConsumer(messageDigest); + messageDigest.reset(); + reader.addConsumer(digesting); + reader.readUserDefinedTag(Tags.NICKNAME); + reader.readString(); + reader.readUserDefinedTag(Tags.PUBLIC_KEY); + byte[] encodedKey = reader.readRaw(); + reader.removeConsumer(digesting); + AuthorId author = new AuthorId(messageDigest.digest()); + // Skip the message body + reader.readUserDefinedTag(Tags.MESSAGE_BODY); + reader.readRaw(); + // Record the length of the signed data + int messageLength = (int) counting.getCount(); + // Read the signature + reader.readUserDefinedTag(Tags.SIGNATURE); + byte[] sig = reader.readRaw(); + reader.removeConsumer(counting); + reader.removeConsumer(copying); + // Verify the signature + PublicKey publicKey; + try { + publicKey = keyParser.parsePublicKey(encodedKey); + } catch(InvalidKeySpecException e) { + throw new FormatException(); + } + byte[] raw = copying.getCopy(); + try { + signature.initVerify(publicKey); + signature.update(raw, 0, messageLength); + if(!signature.verify(sig)) throw new SignatureException(); + } catch(GeneralSecurityException e) { + throw new FormatException(); + } + // Hash the message, including the signature, to get the message ID + messageDigest.reset(); + messageDigest.update(raw); + MessageId id = new MessageId(messageDigest.digest()); + return new MessageImpl(id, parent, group, author, timestamp, raw); + } } diff --git a/components/net/sf/briar/protocol/MessageReaderImpl.java b/components/net/sf/briar/protocol/MessageReaderImpl.java deleted file mode 100644 index f865fc021a..0000000000 --- a/components/net/sf/briar/protocol/MessageReaderImpl.java +++ /dev/null @@ -1,92 +0,0 @@ -package net.sf.briar.protocol; - -import java.io.IOException; -import java.security.GeneralSecurityException; -import java.security.MessageDigest; -import java.security.PublicKey; -import java.security.Signature; -import java.security.SignatureException; -import java.security.spec.InvalidKeySpecException; - -import net.sf.briar.api.crypto.KeyParser; -import net.sf.briar.api.protocol.AuthorId; -import net.sf.briar.api.protocol.GroupId; -import net.sf.briar.api.protocol.Message; -import net.sf.briar.api.protocol.MessageId; -import net.sf.briar.api.protocol.Tags; -import net.sf.briar.api.protocol.UniqueId; -import net.sf.briar.api.serial.FormatException; -import net.sf.briar.api.serial.Reader; - -class MessageReaderImpl implements MessageReader { - - private final KeyParser keyParser; - private final Signature signature; - private final MessageDigest messageDigest; - - MessageReaderImpl(KeyParser keyParser, Signature signature, - MessageDigest messageDigest) { - this.keyParser = keyParser; - this.signature = signature; - this.messageDigest = messageDigest; - } - - public Message readMessage(Reader reader) throws IOException, - GeneralSecurityException { - CopyingConsumer copying = new CopyingConsumer(); - CountingConsumer counting = new CountingConsumer(Message.MAX_SIZE); - DigestingConsumer digesting = new DigestingConsumer(messageDigest); - messageDigest.reset(); - reader.addConsumer(copying); - reader.addConsumer(counting); - // Read the initial tag - reader.readUserDefinedTag(Tags.MESSAGE); - // Read the parent's message ID - reader.readUserDefinedTag(Tags.MESSAGE_ID); - byte[] b = reader.readRaw(); - if(b.length != UniqueId.LENGTH) throw new FormatException(); - MessageId parent = new MessageId(b); - // Read the group ID - reader.readUserDefinedTag(Tags.GROUP_ID); - b = reader.readRaw(); - if(b.length != UniqueId.LENGTH) throw new FormatException(); - GroupId group = new GroupId(b); - // Read the timestamp - reader.readUserDefinedTag(Tags.TIMESTAMP); - long timestamp = reader.readInt64(); - if(timestamp < 0L) throw new FormatException(); - // Hash the author's nick and public key to get the author ID - reader.addConsumer(digesting); - reader.readUserDefinedTag(Tags.AUTHOR); - reader.readString(); - byte[] encodedKey = reader.readRaw(); - reader.removeConsumer(digesting); - AuthorId author = new AuthorId(messageDigest.digest()); - // Skip the message body - reader.readUserDefinedTag(Tags.MESSAGE_BODY); - reader.readRaw(); - // Record the length of the signed data - int messageLength = (int) counting.getCount(); - // Read the signature - reader.readUserDefinedTag(Tags.SIGNATURE); - byte[] sig = reader.readRaw(); - reader.removeConsumer(counting); - reader.removeConsumer(copying); - // Verify the signature - PublicKey publicKey; - try { - publicKey = keyParser.parsePublicKey(encodedKey); - } catch(InvalidKeySpecException e) { - throw new FormatException(); - } - byte[] raw = copying.getCopy(); - signature.initVerify(publicKey); - signature.update(raw, 0, messageLength); - if(!signature.verify(sig)) throw new SignatureException(); - // Hash the message, including the signature, to get the message ID - messageDigest.reset(); - messageDigest.update(raw); - MessageId id = new MessageId(messageDigest.digest()); - return new MessageImpl(id, parent, group, author, timestamp, raw); - } -} diff --git a/components/net/sf/briar/protocol/ProtocolModule.java b/components/net/sf/briar/protocol/ProtocolModule.java index 18d350abb3..bedf1e3506 100644 --- a/components/net/sf/briar/protocol/ProtocolModule.java +++ b/components/net/sf/briar/protocol/ProtocolModule.java @@ -2,7 +2,6 @@ package net.sf.briar.protocol; import net.sf.briar.api.protocol.BundleReader; import net.sf.briar.api.protocol.BundleWriter; -import net.sf.briar.api.protocol.MessageEncoder; import com.google.inject.AbstractModule; @@ -14,7 +13,5 @@ public class ProtocolModule extends AbstractModule { bind(BundleReader.class).to(BundleReaderImpl.class); bind(BundleWriter.class).to(BundleWriterImpl.class); bind(HeaderFactory.class).to(HeaderFactoryImpl.class); - bind(MessageEncoder.class).to(MessageEncoderImpl.class); - bind(MessageReader.class).to(MessageReaderImpl.class); } } diff --git a/components/net/sf/briar/serial/ReaderImpl.java b/components/net/sf/briar/serial/ReaderImpl.java index ebdbf7af62..11d0c2ddaf 100644 --- a/components/net/sf/briar/serial/ReaderImpl.java +++ b/components/net/sf/briar/serial/ReaderImpl.java @@ -9,6 +9,7 @@ import java.util.Map; import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.FormatException; +import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.RawByteArray; import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.Tag; @@ -18,6 +19,9 @@ class ReaderImpl implements Reader { private static final byte[] EMPTY_BUFFER = new byte[] {}; private final InputStream in; + private final Map<Integer, ObjectReader<?>> objectReaders = + new HashMap<Integer, ObjectReader<?>>(); + private Consumer[] consumers = new Consumer[] {}; private boolean started = false, eof = false; private byte next; @@ -72,6 +76,14 @@ class ReaderImpl implements Reader { else throw new IllegalArgumentException(); } + public void addObjectReader(int tag, ObjectReader<?> o) { + objectReaders.put(tag, o); + } + + public void removeObjectReader(int tag) { + objectReaders.remove(tag); + } + public boolean hasBoolean() throws IOException { if(!started) readNext(true); if(eof) return false; @@ -346,6 +358,11 @@ class ReaderImpl implements Reader { private Object readObject() throws IOException { if(!started) throw new IllegalStateException(); + if(hasUserDefinedTag()) { + ObjectReader<?> o = objectReaders.get(readUserDefinedTag()); + if(o == null) throw new FormatException(); + return o.readObject(this); + } if(hasBoolean()) return Boolean.valueOf(readBoolean()); if(hasUint7()) return Byte.valueOf(readUint7()); if(hasInt8()) return Byte.valueOf(readInt8()); @@ -482,4 +499,16 @@ class ReaderImpl implements Reader { public void readUserDefinedTag(int tag) throws IOException { if(readUserDefinedTag() != tag) throw new FormatException(); } + + public <T> T readUserDefinedObject(int tag) throws IOException { + ObjectReader<?> o = objectReaders.get(tag); + if(o == null) throw new FormatException(); + try { + @SuppressWarnings("unchecked") + ObjectReader<T> cast = (ObjectReader<T>) o; + return cast.readObject(this); + } catch(ClassCastException e) { + throw new FormatException(); + } + } } diff --git a/components/net/sf/briar/serial/WriterImpl.java b/components/net/sf/briar/serial/WriterImpl.java index 5f5f06de8e..9f7360075a 100644 --- a/components/net/sf/briar/serial/WriterImpl.java +++ b/components/net/sf/briar/serial/WriterImpl.java @@ -2,12 +2,14 @@ package net.sf.briar.serial; import java.io.IOException; import java.io.OutputStream; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Map.Entry; import net.sf.briar.api.serial.Raw; import net.sf.briar.api.serial.Tag; +import net.sf.briar.api.serial.Writable; import net.sf.briar.api.serial.Writer; class WriterImpl implements Writer { @@ -146,19 +148,20 @@ class WriterImpl implements Writer { writeRaw(r.getBytes()); } - public void writeList(List<?> l) throws IOException { - int length = l.size(); + public void writeList(Collection<?> c) throws IOException { + int length = c.size(); if(length < 16) out.write(intToByte(Tag.SHORT_LIST | length)); else { out.write(Tag.LIST); writeLength(length); } - for(Object o : l) writeObject(o); + for(Object o : c) writeObject(o); bytesWritten++; } private void writeObject(Object o) throws IOException { - if(o instanceof Boolean) writeBoolean((Boolean) o); + if(o instanceof Writable) ((Writable) o).writeTo(this); + else if(o instanceof Boolean) writeBoolean((Boolean) o); else if(o instanceof Byte) writeIntAny((Byte) o); else if(o instanceof Short) writeIntAny((Short) o); else if(o instanceof Integer) writeIntAny((Integer) o); diff --git a/test/build.xml b/test/build.xml index ed32b0c66d..f51ac300af 100644 --- a/test/build.xml +++ b/test/build.xml @@ -22,6 +22,7 @@ <test name='net.sf.briar.invitation.InvitationWorkerTest'/> <test name='net.sf.briar.protocol.BundleReadWriteTest'/> <test name='net.sf.briar.protocol.ConsumersTest'/> + <test name='net.sf.briar.protocol.SigningDigestingOutputStreamTest'/> <test name='net.sf.briar.serial.ReaderImplTest'/> <test name='net.sf.briar.serial.WriterImplTest'/> <test name='net.sf.briar.setup.SetupWorkerTest'/> diff --git a/test/net/sf/briar/protocol/BundleReadWriteTest.java b/test/net/sf/briar/protocol/BundleReadWriteTest.java index 8c5245b60c..b44092a817 100644 --- a/test/net/sf/briar/protocol/BundleReadWriteTest.java +++ b/test/net/sf/briar/protocol/BundleReadWriteTest.java @@ -125,7 +125,7 @@ public class BundleReadWriteTest extends TestCase { testWriteBundle(); MessageReader messageReader = - new MessageReaderImpl(keyParser, sig1, dig1); + new MessageReader(keyParser, sig1, dig1); FileInputStream in = new FileInputStream(bundle); Reader reader = rf.createReader(in); BundleReader r = new BundleReaderImpl(reader, keyPair.getPublic(), sig, @@ -158,14 +158,14 @@ public class BundleReadWriteTest extends TestCase { testWriteBundle(); RandomAccessFile f = new RandomAccessFile(bundle, "rw"); - f.seek(bundle.length() - 150); + f.seek(bundle.length() - 100); byte b = f.readByte(); - f.seek(bundle.length() - 150); + f.seek(bundle.length() - 100); f.writeByte(b + 1); f.close(); MessageReader messageReader = - new MessageReaderImpl(keyParser, sig1, dig1); + new MessageReader(keyParser, sig1, dig1); FileInputStream in = new FileInputStream(bundle); Reader reader = rf.createReader(in); BundleReader r = new BundleReaderImpl(reader, keyPair.getPublic(), sig, diff --git a/test/net/sf/briar/protocol/SigningDigestingOutputStreamTest.java b/test/net/sf/briar/protocol/SigningDigestingOutputStreamTest.java new file mode 100644 index 0000000000..7a04769564 --- /dev/null +++ b/test/net/sf/briar/protocol/SigningDigestingOutputStreamTest.java @@ -0,0 +1,81 @@ +package net.sf.briar.protocol; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.MessageDigest; +import java.security.Signature; +import java.util.Arrays; +import java.util.Random; + +import junit.framework.TestCase; + +import org.junit.Before; +import org.junit.Test; + +public class SigningDigestingOutputStreamTest extends TestCase { + + private static final String SIGNATURE_ALGO = "SHA256withRSA"; + private static final String KEY_PAIR_ALGO = "RSA"; + private static final String DIGEST_ALGO = "SHA-256"; + + private KeyPair keyPair = null; + private Signature sig = null; + private MessageDigest dig = null; + + @Before + public void setUp() throws Exception { + keyPair = KeyPairGenerator.getInstance(KEY_PAIR_ALGO).generateKeyPair(); + sig = Signature.getInstance(SIGNATURE_ALGO); + dig = MessageDigest.getInstance(DIGEST_ALGO); + } + + @Test + public void testStopAndStart() throws Exception { + byte[] input = new byte[1024]; + new Random().nextBytes(input); + ByteArrayOutputStream out = new ByteArrayOutputStream(input.length); + SigningDigestingOutputStream s = + new SigningDigestingOutputStream(out, sig, dig); + sig.initSign(keyPair.getPrivate()); + dig.reset(); + // Sign the first 256 bytes, digest all but the last 256 bytes + s.setDigesting(true); + s.setSigning(true); + s.write(input, 0, 256); + s.setSigning(false); + s.write(input, 256, 512); + s.setDigesting(false); + s.write(input, 768, 256); + s.close(); + // Get the signature and the digest + byte[] signature = sig.sign(); + byte[] digest = dig.digest(); + // Check that the output matches the input + assertTrue(Arrays.equals(input, out.toByteArray())); + // Check that the signature matches a signature over the first 256 bytes + sig.initSign(keyPair.getPrivate()); + sig.update(input, 0, 256); + byte[] directSignature = sig.sign(); + assertTrue(Arrays.equals(directSignature, signature)); + // Check that the digest matches a digest over all but the last 256 + // bytes + dig.reset(); + dig.update(input, 0, 768); + byte[] directDigest = dig.digest(); + assertTrue(Arrays.equals(directDigest, digest)); + } + + @Test + public void testSignatureExceptionThrowsIOException() throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + SigningDigestingOutputStream s = + new SigningDigestingOutputStream(out, sig, dig); + s.setSigning(true); // Signature hasn't been initialised yet + try { + s.write((byte) 0); + assertTrue(false); + } catch(IOException expected) {}; + } +} diff --git a/test/net/sf/briar/serial/ReaderImplTest.java b/test/net/sf/briar/serial/ReaderImplTest.java index 9a1281b800..f4e63071c8 100644 --- a/test/net/sf/briar/serial/ReaderImplTest.java +++ b/test/net/sf/briar/serial/ReaderImplTest.java @@ -8,8 +8,10 @@ import java.util.Map; import java.util.Map.Entry; import junit.framework.TestCase; +import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.Raw; import net.sf.briar.api.serial.RawByteArray; +import net.sf.briar.api.serial.Reader; import net.sf.briar.util.StringUtils; import org.junit.Test; @@ -326,6 +328,56 @@ public class ReaderImplTest extends TestCase { assertTrue(r.eof()); } + @Test + public void testReadUserDefinedObject() throws IOException { + setContents("C0" + "83666F6F"); + // Add an object reader for a user-defined type + r.addObjectReader(0, new ObjectReader<Foo>() { + public Foo readObject(Reader r) throws IOException { + return new Foo(r.readString()); + } + }); + assertEquals(0, r.readUserDefinedTag()); + assertEquals("foo", r.<Foo>readUserDefinedObject(0).s); + } + + @Test + public void testReadListUsingObjectReader() throws IOException { + setContents("A" + "1" + "C0" + "83666F6F"); + // Add an object reader for a user-defined type + r.addObjectReader(0, new ObjectReader<Foo>() { + public Foo readObject(Reader r) throws IOException { + return new Foo(r.readString()); + } + }); + // Check that the object reader is used for lists + List<Foo> l = r.readList(Foo.class); + assertEquals(1, l.size()); + assertEquals("foo", l.get(0).s); + } + + @Test + public void testReadMapUsingObjectReader() throws IOException { + setContents("B" + "1" + "C0" + "83666F6F" + "C1" + "83626172"); + // Add object readers for two user-defined types + r.addObjectReader(0, new ObjectReader<Foo>() { + public Foo readObject(Reader r) throws IOException { + return new Foo(r.readString()); + } + }); + r.addObjectReader(1, new ObjectReader<Bar>() { + public Bar readObject(Reader r) throws IOException { + return new Bar(r.readString()); + } + }); + // Check that the object readers are used for maps + Map<Foo, Bar> m = r.readMap(Foo.class, Bar.class); + assertEquals(1, m.size()); + Entry<Foo, Bar> e = m.entrySet().iterator().next(); + assertEquals("foo", e.getKey().s); + assertEquals("bar", e.getValue().s); + } + @Test public void testReadEmptyInput() throws IOException { setContents(""); @@ -336,4 +388,22 @@ public class ReaderImplTest extends TestCase { in = new ByteArrayInputStream(StringUtils.fromHexString(hex)); r = new ReaderImpl(in); } + + private static class Foo { + + private final String s; + + private Foo(String s) { + this.s = s; + } + } + + private static class Bar { + + private final String s; + + private Bar(String s) { + this.s = s; + } + } } diff --git a/test/net/sf/briar/serial/WriterImplTest.java b/test/net/sf/briar/serial/WriterImplTest.java index c45983d123..6bc4b7c2f5 100644 --- a/test/net/sf/briar/serial/WriterImplTest.java +++ b/test/net/sf/briar/serial/WriterImplTest.java @@ -4,12 +4,15 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import junit.framework.TestCase; import net.sf.briar.api.serial.RawByteArray; +import net.sf.briar.api.serial.Writable; +import net.sf.briar.api.serial.Writer; import net.sf.briar.util.StringUtils; import org.junit.Before; @@ -300,6 +303,20 @@ public class WriterImplTest extends TestCase { checkContents("EF" + "20" + "EF" + "FB7FFFFFFF"); } + @Test + public void testWriteCollectionOfWritables() throws IOException { + Writable writable = new Writable() { + public void writeTo(Writer w) throws IOException { + w.writeUserDefinedTag(0); + w.writeString("foo"); + } + }; + w.writeList(Collections.singleton(writable)); + // SHORT_LIST tag, length 1, SHORT_USER tag (3 bits), 0 (5 bits), + // "foo" as short string + checkContents("A" + "1" + "C0" + "83666F6F"); + } + private void checkContents(String hex) throws IOException { out.flush(); out.close(); -- GitLab