From e24a3218caeac42f5c0035ccde2fb4c3d4952628 Mon Sep 17 00:00:00 2001 From: akwizgran <akwizgran@users.sourceforge.net> Date: Mon, 5 Dec 2011 22:52:00 +0000 Subject: [PATCH] Moved message verification and DB writes off the IO thread. --- .../sf/briar/api/protocol/ProtocolReader.java | 2 +- .../briar/api/protocol/UnverifiedBatch.java | 8 ++ .../net/sf/briar/protocol/BatchFactory.java | 12 --- .../sf/briar/protocol/BatchFactoryImpl.java | 14 ---- .../net/sf/briar/protocol/BatchReader.java | 20 ++--- .../net/sf/briar/protocol/MessageReader.java | 55 ++---------- .../net/sf/briar/protocol/ProtocolModule.java | 18 ++-- .../protocol/ProtocolReaderFactoryImpl.java | 6 +- .../sf/briar/protocol/ProtocolReaderImpl.java | 11 +-- .../protocol/UnverifiedBatchFactory.java | 12 +++ .../protocol/UnverifiedBatchFactoryImpl.java | 24 ++++++ .../briar/protocol/UnverifiedBatchImpl.java | 83 +++++++++++++++++++ .../sf/briar/protocol/UnverifiedMessage.java | 32 +++++++ .../briar/protocol/UnverifiedMessageImpl.java | 82 ++++++++++++++++++ .../batch/BatchConnectionFactoryImpl.java | 10 ++- .../batch/IncomingBatchConnection.java | 73 ++++++++++++---- .../stream/IncomingStreamConnection.java | 9 +- .../stream/OutgoingStreamConnection.java | 8 +- .../transport/stream/StreamConnection.java | 83 ++++++++++++++++--- .../stream/StreamConnectionFactoryImpl.java | 11 ++- .../net/sf/briar/ProtocolIntegrationTest.java | 4 +- .../sf/briar/protocol/BatchReaderTest.java | 54 ++++++------ .../briar/protocol/ProtocolReadWriteTest.java | 2 +- .../batch/BatchConnectionReadWriteTest.java | 4 +- 24 files changed, 468 insertions(+), 169 deletions(-) create mode 100644 api/net/sf/briar/api/protocol/UnverifiedBatch.java delete mode 100644 components/net/sf/briar/protocol/BatchFactory.java delete mode 100644 components/net/sf/briar/protocol/BatchFactoryImpl.java create mode 100644 components/net/sf/briar/protocol/UnverifiedBatchFactory.java create mode 100644 components/net/sf/briar/protocol/UnverifiedBatchFactoryImpl.java create mode 100644 components/net/sf/briar/protocol/UnverifiedBatchImpl.java create mode 100644 components/net/sf/briar/protocol/UnverifiedMessage.java create mode 100644 components/net/sf/briar/protocol/UnverifiedMessageImpl.java diff --git a/api/net/sf/briar/api/protocol/ProtocolReader.java b/api/net/sf/briar/api/protocol/ProtocolReader.java index 781eccfba0..104b0d60a5 100644 --- a/api/net/sf/briar/api/protocol/ProtocolReader.java +++ b/api/net/sf/briar/api/protocol/ProtocolReader.java @@ -10,7 +10,7 @@ public interface ProtocolReader { Ack readAck() throws IOException; boolean hasBatch() throws IOException; - Batch readBatch() throws IOException; + UnverifiedBatch readBatch() throws IOException; boolean hasOffer() throws IOException; Offer readOffer() throws IOException; diff --git a/api/net/sf/briar/api/protocol/UnverifiedBatch.java b/api/net/sf/briar/api/protocol/UnverifiedBatch.java new file mode 100644 index 0000000000..abd1f15159 --- /dev/null +++ b/api/net/sf/briar/api/protocol/UnverifiedBatch.java @@ -0,0 +1,8 @@ +package net.sf.briar.api.protocol; + +import java.security.GeneralSecurityException; + +public interface UnverifiedBatch { + + Batch verify() throws GeneralSecurityException; +} diff --git a/components/net/sf/briar/protocol/BatchFactory.java b/components/net/sf/briar/protocol/BatchFactory.java deleted file mode 100644 index 5bbffdbcd0..0000000000 --- a/components/net/sf/briar/protocol/BatchFactory.java +++ /dev/null @@ -1,12 +0,0 @@ -package net.sf.briar.protocol; - -import java.util.Collection; - -import net.sf.briar.api.protocol.Batch; -import net.sf.briar.api.protocol.BatchId; -import net.sf.briar.api.protocol.Message; - -interface BatchFactory { - - Batch createBatch(BatchId id, Collection<Message> messages); -} diff --git a/components/net/sf/briar/protocol/BatchFactoryImpl.java b/components/net/sf/briar/protocol/BatchFactoryImpl.java deleted file mode 100644 index a21bef6585..0000000000 --- a/components/net/sf/briar/protocol/BatchFactoryImpl.java +++ /dev/null @@ -1,14 +0,0 @@ -package net.sf.briar.protocol; - -import java.util.Collection; - -import net.sf.briar.api.protocol.Batch; -import net.sf.briar.api.protocol.BatchId; -import net.sf.briar.api.protocol.Message; - -class BatchFactoryImpl implements BatchFactory { - - public Batch createBatch(BatchId id, Collection<Message> messages) { - return new BatchImpl(id, messages); - } -} diff --git a/components/net/sf/briar/protocol/BatchReader.java b/components/net/sf/briar/protocol/BatchReader.java index 164c5e4606..1f67b3d1b0 100644 --- a/components/net/sf/briar/protocol/BatchReader.java +++ b/components/net/sf/briar/protocol/BatchReader.java @@ -5,31 +5,31 @@ import java.util.List; import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.MessageDigest; -import net.sf.briar.api.protocol.Batch; import net.sf.briar.api.protocol.BatchId; -import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Types; +import net.sf.briar.api.protocol.UnverifiedBatch; import net.sf.briar.api.serial.Consumer; import net.sf.briar.api.serial.CountingConsumer; import net.sf.briar.api.serial.DigestingConsumer; import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.Reader; -class BatchReader implements ObjectReader<Batch> { +class BatchReader implements ObjectReader<UnverifiedBatch> { private final MessageDigest messageDigest; - private final ObjectReader<Message> messageReader; - private final BatchFactory batchFactory; + private final ObjectReader<UnverifiedMessage> messageReader; + private final UnverifiedBatchFactory batchFactory; - BatchReader(CryptoComponent crypto, ObjectReader<Message> messageReader, - BatchFactory batchFactory) { + BatchReader(CryptoComponent crypto, + ObjectReader<UnverifiedMessage> messageReader, + UnverifiedBatchFactory batchFactory) { messageDigest = crypto.getMessageDigest(); this.messageReader = messageReader; this.batchFactory = batchFactory; } - public Batch readObject(Reader r) throws IOException { + public UnverifiedBatch readObject(Reader r) throws IOException { // Initialise the consumers Consumer counting = new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH); @@ -40,12 +40,12 @@ class BatchReader implements ObjectReader<Batch> { r.addConsumer(digesting); r.readStructId(Types.BATCH); r.addObjectReader(Types.MESSAGE, messageReader); - List<Message> messages = r.readList(Message.class); + List<UnverifiedMessage> messages = r.readList(UnverifiedMessage.class); r.removeObjectReader(Types.MESSAGE); r.removeConsumer(digesting); r.removeConsumer(counting); // Build and return the batch BatchId id = new BatchId(messageDigest.digest()); - return batchFactory.createBatch(id, messages); + return batchFactory.createUnverifiedBatch(id, messages); } } diff --git a/components/net/sf/briar/protocol/MessageReader.java b/components/net/sf/briar/protocol/MessageReader.java index d86a7bfcb9..0409daf0b0 100644 --- a/components/net/sf/briar/protocol/MessageReader.java +++ b/components/net/sf/briar/protocol/MessageReader.java @@ -1,19 +1,10 @@ package net.sf.briar.protocol; import java.io.IOException; -import java.security.GeneralSecurityException; -import java.security.PublicKey; -import java.security.Signature; import net.sf.briar.api.FormatException; -import net.sf.briar.api.crypto.CryptoComponent; -import net.sf.briar.api.crypto.KeyParser; -import net.sf.briar.api.crypto.MessageDigest; import net.sf.briar.api.protocol.Author; -import net.sf.briar.api.protocol.AuthorId; import net.sf.briar.api.protocol.Group; -import net.sf.briar.api.protocol.GroupId; -import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Types; @@ -22,28 +13,21 @@ import net.sf.briar.api.serial.CountingConsumer; import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.Reader; -class MessageReader implements ObjectReader<Message> { +class MessageReader implements ObjectReader<UnverifiedMessage> { private final ObjectReader<MessageId> messageIdReader; private final ObjectReader<Group> groupReader; private final ObjectReader<Author> authorReader; - private final KeyParser keyParser; - private final Signature signature; - private final MessageDigest messageDigest; - MessageReader(CryptoComponent crypto, - ObjectReader<MessageId> messageIdReader, + MessageReader(ObjectReader<MessageId> messageIdReader, ObjectReader<Group> groupReader, ObjectReader<Author> authorReader) { this.messageIdReader = messageIdReader; this.groupReader = groupReader; this.authorReader = authorReader; - keyParser = crypto.getKeyParser(); - signature = crypto.getSignature(); - messageDigest = crypto.getMessageDigest(); } - public Message readObject(Reader r) throws IOException { + public UnverifiedMessage readObject(Reader r) throws IOException { CopyingConsumer copying = new CopyingConsumer(); CountingConsumer counting = new CountingConsumer(ProtocolConstants.MAX_PACKET_LENGTH); @@ -106,35 +90,8 @@ class MessageReader implements ObjectReader<Message> { r.removeConsumer(counting); r.removeConsumer(copying); byte[] raw = copying.getCopy(); - // Verify the author's signature, if there is one - if(author != null) { - try { - PublicKey k = keyParser.parsePublicKey(author.getPublicKey()); - signature.initVerify(k); - signature.update(raw, 0, signedByAuthor); - if(!signature.verify(authorSig)) throw new FormatException(); - } catch(GeneralSecurityException e) { - throw new FormatException(); - } - } - // Verify the group's signature, if there is one - if(group != null && group.getPublicKey() != null) { - try { - PublicKey k = keyParser.parsePublicKey(group.getPublicKey()); - signature.initVerify(k); - signature.update(raw, 0, signedByGroup); - if(!signature.verify(groupSig)) throw new FormatException(); - } catch(GeneralSecurityException e) { - throw new FormatException(); - } - } - // Hash the message, including the signatures, to get the message ID - messageDigest.reset(); - messageDigest.update(raw); - MessageId id = new MessageId(messageDigest.digest()); - GroupId groupId = group == null ? null : group.getId(); - AuthorId authorId = author == null ? null : author.getId(); - return new MessageImpl(id, parent, groupId, authorId, subject, - timestamp, raw, bodyStart, body.length); + return new UnverifiedMessageImpl(parent, group, author, subject, + timestamp, raw, authorSig, groupSig, bodyStart, body.length, + signedByAuthor, signedByGroup); } } diff --git a/components/net/sf/briar/protocol/ProtocolModule.java b/components/net/sf/briar/protocol/ProtocolModule.java index 99d3f0b642..77ab27ed09 100644 --- a/components/net/sf/briar/protocol/ProtocolModule.java +++ b/components/net/sf/briar/protocol/ProtocolModule.java @@ -4,10 +4,8 @@ import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Author; import net.sf.briar.api.protocol.AuthorFactory; -import net.sf.briar.api.protocol.Batch; import net.sf.briar.api.protocol.Group; import net.sf.briar.api.protocol.GroupFactory; -import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.MessageFactory; import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.Offer; @@ -15,6 +13,7 @@ import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.TransportUpdate; +import net.sf.briar.api.protocol.UnverifiedBatch; import net.sf.briar.api.serial.ObjectReader; import com.google.inject.AbstractModule; @@ -26,14 +25,15 @@ public class ProtocolModule extends AbstractModule { protected void configure() { bind(AckFactory.class).to(AckFactoryImpl.class); bind(AuthorFactory.class).to(AuthorFactoryImpl.class); - bind(BatchFactory.class).to(BatchFactoryImpl.class); bind(GroupFactory.class).to(GroupFactoryImpl.class); bind(MessageFactory.class).to(MessageFactoryImpl.class); bind(OfferFactory.class).to(OfferFactoryImpl.class); bind(ProtocolReaderFactory.class).to(ProtocolReaderFactoryImpl.class); bind(RequestFactory.class).to(RequestFactoryImpl.class); - bind(SubscriptionUpdateFactory.class).to(SubscriptionUpdateFactoryImpl.class); + bind(SubscriptionUpdateFactory.class).to( + SubscriptionUpdateFactoryImpl.class); bind(TransportUpdateFactory.class).to(TransportUpdateFactoryImpl.class); + bind(UnverifiedBatchFactory.class).to(UnverifiedBatchFactoryImpl.class); } @Provides @@ -48,8 +48,9 @@ public class ProtocolModule extends AbstractModule { } @Provides - ObjectReader<Batch> getBatchReader(CryptoComponent crypto, - ObjectReader<Message> messageReader, BatchFactory batchFactory) { + ObjectReader<UnverifiedBatch> getBatchReader(CryptoComponent crypto, + ObjectReader<UnverifiedMessage> messageReader, + UnverifiedBatchFactory batchFactory) { return new BatchReader(crypto, messageReader, batchFactory); } @@ -65,12 +66,11 @@ public class ProtocolModule extends AbstractModule { } @Provides - ObjectReader<Message> getMessageReader(CryptoComponent crypto, + ObjectReader<UnverifiedMessage> getMessageReader( ObjectReader<MessageId> messageIdReader, ObjectReader<Group> groupReader, ObjectReader<Author> authorReader) { - return new MessageReader(crypto, messageIdReader, groupReader, - authorReader); + return new MessageReader(messageIdReader, groupReader, authorReader); } @Provides diff --git a/components/net/sf/briar/protocol/ProtocolReaderFactoryImpl.java b/components/net/sf/briar/protocol/ProtocolReaderFactoryImpl.java index d118880501..e74d6f2aed 100644 --- a/components/net/sf/briar/protocol/ProtocolReaderFactoryImpl.java +++ b/components/net/sf/briar/protocol/ProtocolReaderFactoryImpl.java @@ -3,13 +3,13 @@ package net.sf.briar.protocol; import java.io.InputStream; import net.sf.briar.api.protocol.Ack; -import net.sf.briar.api.protocol.Batch; import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.ProtocolReader; import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.TransportUpdate; +import net.sf.briar.api.protocol.UnverifiedBatch; import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.ReaderFactory; @@ -20,7 +20,7 @@ class ProtocolReaderFactoryImpl implements ProtocolReaderFactory { private final ReaderFactory readerFactory; private final Provider<ObjectReader<Ack>> ackProvider; - private final Provider<ObjectReader<Batch>> batchProvider; + private final Provider<ObjectReader<UnverifiedBatch>> batchProvider; private final Provider<ObjectReader<Offer>> offerProvider; private final Provider<ObjectReader<Request>> requestProvider; private final Provider<ObjectReader<SubscriptionUpdate>> subscriptionProvider; @@ -29,7 +29,7 @@ class ProtocolReaderFactoryImpl implements ProtocolReaderFactory { @Inject ProtocolReaderFactoryImpl(ReaderFactory readerFactory, Provider<ObjectReader<Ack>> ackProvider, - Provider<ObjectReader<Batch>> batchProvider, + Provider<ObjectReader<UnverifiedBatch>> batchProvider, Provider<ObjectReader<Offer>> offerProvider, Provider<ObjectReader<Request>> requestProvider, Provider<ObjectReader<SubscriptionUpdate>> subscriptionProvider, diff --git a/components/net/sf/briar/protocol/ProtocolReaderImpl.java b/components/net/sf/briar/protocol/ProtocolReaderImpl.java index 893b437e4e..a556361341 100644 --- a/components/net/sf/briar/protocol/ProtocolReaderImpl.java +++ b/components/net/sf/briar/protocol/ProtocolReaderImpl.java @@ -4,13 +4,13 @@ import java.io.IOException; import java.io.InputStream; import net.sf.briar.api.protocol.Ack; -import net.sf.briar.api.protocol.Batch; import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.ProtocolReader; import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.SubscriptionUpdate; -import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.TransportUpdate; +import net.sf.briar.api.protocol.Types; +import net.sf.briar.api.protocol.UnverifiedBatch; import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.ReaderFactory; @@ -20,7 +20,8 @@ class ProtocolReaderImpl implements ProtocolReader { private final Reader reader; ProtocolReaderImpl(InputStream in, ReaderFactory readerFactory, - ObjectReader<Ack> ackReader, ObjectReader<Batch> batchReader, + ObjectReader<Ack> ackReader, + ObjectReader<UnverifiedBatch> batchReader, ObjectReader<Offer> offerReader, ObjectReader<Request> requestReader, ObjectReader<SubscriptionUpdate> subscriptionReader, @@ -50,8 +51,8 @@ class ProtocolReaderImpl implements ProtocolReader { return reader.hasStruct(Types.BATCH); } - public Batch readBatch() throws IOException { - return reader.readStruct(Types.BATCH, Batch.class); + public UnverifiedBatch readBatch() throws IOException { + return reader.readStruct(Types.BATCH, UnverifiedBatch.class); } public boolean hasOffer() throws IOException { diff --git a/components/net/sf/briar/protocol/UnverifiedBatchFactory.java b/components/net/sf/briar/protocol/UnverifiedBatchFactory.java new file mode 100644 index 0000000000..c420dd6b68 --- /dev/null +++ b/components/net/sf/briar/protocol/UnverifiedBatchFactory.java @@ -0,0 +1,12 @@ +package net.sf.briar.protocol; + +import java.util.Collection; + +import net.sf.briar.api.protocol.BatchId; +import net.sf.briar.api.protocol.UnverifiedBatch; + +interface UnverifiedBatchFactory { + + UnverifiedBatch createUnverifiedBatch(BatchId id, + Collection<UnverifiedMessage> messages); +} diff --git a/components/net/sf/briar/protocol/UnverifiedBatchFactoryImpl.java b/components/net/sf/briar/protocol/UnverifiedBatchFactoryImpl.java new file mode 100644 index 0000000000..b762a278d1 --- /dev/null +++ b/components/net/sf/briar/protocol/UnverifiedBatchFactoryImpl.java @@ -0,0 +1,24 @@ +package net.sf.briar.protocol; + +import java.util.Collection; + +import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.api.protocol.BatchId; +import net.sf.briar.api.protocol.UnverifiedBatch; + +import com.google.inject.Inject; + +class UnverifiedBatchFactoryImpl implements UnverifiedBatchFactory { + + private final CryptoComponent crypto; + + @Inject + UnverifiedBatchFactoryImpl(CryptoComponent crypto) { + this.crypto = crypto; + } + + public UnverifiedBatch createUnverifiedBatch(BatchId id, + Collection<UnverifiedMessage> messages) { + return new UnverifiedBatchImpl(crypto, id, messages); + } +} diff --git a/components/net/sf/briar/protocol/UnverifiedBatchImpl.java b/components/net/sf/briar/protocol/UnverifiedBatchImpl.java new file mode 100644 index 0000000000..8665de30a4 --- /dev/null +++ b/components/net/sf/briar/protocol/UnverifiedBatchImpl.java @@ -0,0 +1,83 @@ +package net.sf.briar.protocol; + +import java.security.GeneralSecurityException; +import java.security.PublicKey; +import java.security.Signature; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.api.crypto.KeyParser; +import net.sf.briar.api.crypto.MessageDigest; +import net.sf.briar.api.protocol.Author; +import net.sf.briar.api.protocol.AuthorId; +import net.sf.briar.api.protocol.Batch; +import net.sf.briar.api.protocol.BatchId; +import net.sf.briar.api.protocol.Group; +import net.sf.briar.api.protocol.GroupId; +import net.sf.briar.api.protocol.Message; +import net.sf.briar.api.protocol.MessageId; +import net.sf.briar.api.protocol.UnverifiedBatch; + +class UnverifiedBatchImpl implements UnverifiedBatch { + + private final CryptoComponent crypto; + private final BatchId id; + private final Collection<UnverifiedMessage> messages; + + // Initialise lazily - the batch may be empty or contain unsigned messages + private MessageDigest messageDigest = null; + private KeyParser keyParser = null; + private Signature signature = null; + + UnverifiedBatchImpl(CryptoComponent crypto, BatchId id, + Collection<UnverifiedMessage> messages) { + this.crypto = crypto; + this.id = id; + this.messages = messages; + } + + public Batch verify() throws GeneralSecurityException { + List<Message> verified = new ArrayList<Message>(); + for(UnverifiedMessage m : messages) verified.add(verify(m)); + return new BatchImpl(id, Collections.unmodifiableList(verified)); + } + + private Message verify(UnverifiedMessage m) + throws GeneralSecurityException { + // Hash the message, including the signatures, to get the message ID + byte[] raw = m.getRaw(); + if(messageDigest == null) messageDigest = crypto.getMessageDigest(); + messageDigest.update(raw); + MessageId id = new MessageId(messageDigest.digest()); + // Verify the author's signature, if there is one + Author author = m.getAuthor(); + if(author != null) { + if(keyParser == null) keyParser = crypto.getKeyParser(); + PublicKey k = keyParser.parsePublicKey(author.getPublicKey()); + if(signature == null) signature = crypto.getSignature(); + signature.initVerify(k); + signature.update(raw, 0, m.getLengthSignedByAuthor()); + if(!signature.verify(m.getAuthorSignature())) + throw new GeneralSecurityException(); + } + // Verify the group's signature, if there is one + Group group = m.getGroup(); + if(group != null && group.getPublicKey() != null) { + if(keyParser == null) keyParser = crypto.getKeyParser(); + PublicKey k = keyParser.parsePublicKey(group.getPublicKey()); + if(signature == null) signature = crypto.getSignature(); + signature.initVerify(k); + signature.update(raw, 0, m.getLengthSignedByGroup()); + if(!signature.verify(m.getGroupSignature())) + throw new GeneralSecurityException(); + } + GroupId groupId = group == null ? null : group.getId(); + AuthorId authorId = author == null ? null : author.getId(); + return new MessageImpl(id, m.getParent(), groupId, authorId, + m.getSubject(), m.getTimestamp(), raw, m.getBodyStart(), + m.getBodyLength()); + } +} diff --git a/components/net/sf/briar/protocol/UnverifiedMessage.java b/components/net/sf/briar/protocol/UnverifiedMessage.java new file mode 100644 index 0000000000..64866d2a01 --- /dev/null +++ b/components/net/sf/briar/protocol/UnverifiedMessage.java @@ -0,0 +1,32 @@ +package net.sf.briar.protocol; + +import net.sf.briar.api.protocol.Author; +import net.sf.briar.api.protocol.Group; +import net.sf.briar.api.protocol.MessageId; + +interface UnverifiedMessage { + + MessageId getParent(); + + Group getGroup(); + + Author getAuthor(); + + String getSubject(); + + long getTimestamp(); + + byte[] getRaw(); + + byte[] getAuthorSignature(); + + byte[] getGroupSignature(); + + int getBodyStart(); + + int getBodyLength(); + + int getLengthSignedByAuthor(); + + int getLengthSignedByGroup(); +} \ No newline at end of file diff --git a/components/net/sf/briar/protocol/UnverifiedMessageImpl.java b/components/net/sf/briar/protocol/UnverifiedMessageImpl.java new file mode 100644 index 0000000000..dd3f941c62 --- /dev/null +++ b/components/net/sf/briar/protocol/UnverifiedMessageImpl.java @@ -0,0 +1,82 @@ +package net.sf.briar.protocol; + +import net.sf.briar.api.protocol.Author; +import net.sf.briar.api.protocol.Group; +import net.sf.briar.api.protocol.MessageId; + +class UnverifiedMessageImpl implements UnverifiedMessage { + + private final MessageId parent; + private final Group group; + private final Author author; + private final String subject; + private final long timestamp; + private final byte[] raw, authorSig, groupSig; + private final int bodyStart, bodyLength, signedByAuthor, signedByGroup; + + UnverifiedMessageImpl(MessageId parent, Group group, Author author, + String subject, long timestamp, byte[] raw, byte[] authorSig, + byte[] groupSig, int bodyStart, int bodyLength, int signedByAuthor, + int signedByGroup) { + this.parent = parent; + this.group = group; + this.author = author; + this.subject = subject; + this.timestamp = timestamp; + this.raw = raw; + this.authorSig = authorSig; + this.groupSig = groupSig; + this.bodyStart = bodyStart; + this.bodyLength = bodyLength; + this.signedByAuthor = signedByAuthor; + this.signedByGroup = signedByGroup; + } + + public MessageId getParent() { + return parent; + } + + public Group getGroup() { + return group; + } + + public Author getAuthor() { + return author; + } + + public String getSubject() { + return subject; + } + + public long getTimestamp() { + return timestamp; + } + + public byte[] getRaw() { + return raw; + } + + public byte[] getAuthorSignature() { + return authorSig; + } + + public byte[] getGroupSignature() { + return groupSig; + } + + public int getBodyStart() { + return bodyStart; + } + + public int getBodyLength() { + return bodyLength; + } + + public int getLengthSignedByAuthor() { + return signedByAuthor; + } + + public int getLengthSignedByGroup() { + return signedByGroup; + } +} diff --git a/components/net/sf/briar/transport/batch/BatchConnectionFactoryImpl.java b/components/net/sf/briar/transport/batch/BatchConnectionFactoryImpl.java index da49fc6aec..81a2152357 100644 --- a/components/net/sf/briar/transport/batch/BatchConnectionFactoryImpl.java +++ b/components/net/sf/briar/transport/batch/BatchConnectionFactoryImpl.java @@ -1,5 +1,7 @@ package net.sf.briar.transport.batch; +import java.util.concurrent.Executor; + import net.sf.briar.api.ContactId; import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.protocol.ProtocolReaderFactory; @@ -16,6 +18,7 @@ import com.google.inject.Inject; class BatchConnectionFactoryImpl implements BatchConnectionFactory { + private final Executor executor; private final ConnectionReaderFactory connReaderFactory; private final ConnectionWriterFactory connWriterFactory; private final DatabaseComponent db; @@ -23,10 +26,12 @@ class BatchConnectionFactoryImpl implements BatchConnectionFactory { private final ProtocolWriterFactory protoWriterFactory; @Inject - BatchConnectionFactoryImpl(ConnectionReaderFactory connReaderFactory, + BatchConnectionFactoryImpl(Executor executor, + ConnectionReaderFactory connReaderFactory, ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ProtocolReaderFactory protoReaderFactory, ProtocolWriterFactory protoWriterFactory) { + this.executor = executor; this.connReaderFactory = connReaderFactory; this.connWriterFactory = connWriterFactory; this.db = db; @@ -37,7 +42,8 @@ class BatchConnectionFactoryImpl implements BatchConnectionFactory { public void createIncomingConnection(ConnectionContext ctx, BatchTransportReader r, byte[] tag) { final IncomingBatchConnection conn = new IncomingBatchConnection( - connReaderFactory, db, protoReaderFactory, ctx, r, tag); + executor, connReaderFactory, db, protoReaderFactory, ctx, r, + tag); Runnable read = new Runnable() { public void run() { conn.read(); diff --git a/components/net/sf/briar/transport/batch/IncomingBatchConnection.java b/components/net/sf/briar/transport/batch/IncomingBatchConnection.java index fdcc60541f..c8d63980e6 100644 --- a/components/net/sf/briar/transport/batch/IncomingBatchConnection.java +++ b/components/net/sf/briar/transport/batch/IncomingBatchConnection.java @@ -1,6 +1,8 @@ package net.sf.briar.transport.batch; import java.io.IOException; +import java.security.GeneralSecurityException; +import java.util.concurrent.Executor; import java.util.logging.Level; import java.util.logging.Logger; @@ -9,11 +11,11 @@ import net.sf.briar.api.FormatException; import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DbException; import net.sf.briar.api.protocol.Ack; -import net.sf.briar.api.protocol.Batch; import net.sf.briar.api.protocol.ProtocolReader; import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.TransportUpdate; +import net.sf.briar.api.protocol.UnverifiedBatch; import net.sf.briar.api.transport.BatchTransportReader; import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionReader; @@ -24,6 +26,7 @@ class IncomingBatchConnection { private static final Logger LOG = Logger.getLogger(IncomingBatchConnection.class.getName()); + private final Executor executor; private final ConnectionReaderFactory connFactory; private final DatabaseComponent db; private final ProtocolReaderFactory protoFactory; @@ -31,9 +34,11 @@ class IncomingBatchConnection { private final BatchTransportReader reader; private final byte[] tag; - IncomingBatchConnection(ConnectionReaderFactory connFactory, + IncomingBatchConnection(Executor executor, + ConnectionReaderFactory connFactory, DatabaseComponent db, ProtocolReaderFactory protoFactory, ConnectionContext ctx, BatchTransportReader reader, byte[] tag) { + this.executor = executor; this.connFactory = connFactory; this.db = db; this.protoFactory = protoFactory; @@ -48,28 +53,68 @@ class IncomingBatchConnection { reader.getInputStream(), ctx.getSecret(), tag); ProtocolReader proto = protoFactory.createProtocolReader( conn.getInputStream()); - ContactId c = ctx.getContactId(); + final ContactId c = ctx.getContactId(); // Read packets until EOF while(!proto.eof()) { if(proto.hasAck()) { - Ack a = proto.readAck(); - db.receiveAck(c, a); + final Ack a = proto.readAck(); + // Store the ack on another thread + executor.execute(new Runnable() { + public void run() { + try { + db.receiveAck(c, a); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) + LOG.warning(e.getMessage()); + } + } + }); } else if(proto.hasBatch()) { - Batch b = proto.readBatch(); - db.receiveBatch(c, b); + final UnverifiedBatch b = proto.readBatch(); + // Verify and store the batch on another thread + executor.execute(new Runnable() { + public void run() { + try { + db.receiveBatch(c, b.verify()); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) + LOG.warning(e.getMessage()); + } catch(GeneralSecurityException e) { + if(LOG.isLoggable(Level.WARNING)) + LOG.warning(e.getMessage()); + } + } + }); } else if(proto.hasSubscriptionUpdate()) { - SubscriptionUpdate s = proto.readSubscriptionUpdate(); - db.receiveSubscriptionUpdate(c, s); + final SubscriptionUpdate s = proto.readSubscriptionUpdate(); + // Store the update on another thread + executor.execute(new Runnable() { + public void run() { + try { + db.receiveSubscriptionUpdate(c, s); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) + LOG.warning(e.getMessage()); + } + } + }); } else if(proto.hasTransportUpdate()) { - TransportUpdate t = proto.readTransportUpdate(); - db.receiveTransportUpdate(c, t); + final TransportUpdate t = proto.readTransportUpdate(); + // Store the update on another thread + executor.execute(new Runnable() { + public void run() { + try { + db.receiveTransportUpdate(c, t); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) + LOG.warning(e.getMessage()); + } + } + }); } else { throw new FormatException(); } } - } catch(DbException e) { - if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); - reader.dispose(false); } catch(IOException e) { if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); reader.dispose(false); diff --git a/components/net/sf/briar/transport/stream/IncomingStreamConnection.java b/components/net/sf/briar/transport/stream/IncomingStreamConnection.java index 16da6c7952..006659e103 100644 --- a/components/net/sf/briar/transport/stream/IncomingStreamConnection.java +++ b/components/net/sf/briar/transport/stream/IncomingStreamConnection.java @@ -1,6 +1,7 @@ package net.sf.briar.transport.stream; import java.io.IOException; +import java.util.concurrent.Executor; import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DbException; @@ -18,14 +19,16 @@ class IncomingStreamConnection extends StreamConnection { private final ConnectionContext ctx; private final byte[] tag; - IncomingStreamConnection(ConnectionReaderFactory connReaderFactory, + IncomingStreamConnection(Executor executor, + ConnectionReaderFactory connReaderFactory, ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ProtocolReaderFactory protoReaderFactory, ProtocolWriterFactory protoWriterFactory, ConnectionContext ctx, StreamTransportConnection connection, byte[] tag) { - super(connReaderFactory, connWriterFactory, db, protoReaderFactory, - protoWriterFactory, ctx.getContactId(), connection); + super(executor, connReaderFactory, connWriterFactory, db, + protoReaderFactory, protoWriterFactory, ctx.getContactId(), + connection); this.ctx = ctx; this.tag = tag; } diff --git a/components/net/sf/briar/transport/stream/OutgoingStreamConnection.java b/components/net/sf/briar/transport/stream/OutgoingStreamConnection.java index 64d61e1306..5f1dabbbbc 100644 --- a/components/net/sf/briar/transport/stream/OutgoingStreamConnection.java +++ b/components/net/sf/briar/transport/stream/OutgoingStreamConnection.java @@ -1,6 +1,7 @@ package net.sf.briar.transport.stream; import java.io.IOException; +import java.util.concurrent.Executor; import net.sf.briar.api.ContactId; import net.sf.briar.api.db.DatabaseComponent; @@ -21,14 +22,15 @@ class OutgoingStreamConnection extends StreamConnection { private ConnectionContext ctx = null; // Locking: this - OutgoingStreamConnection(ConnectionReaderFactory connReaderFactory, + OutgoingStreamConnection(Executor executor, + ConnectionReaderFactory connReaderFactory, ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ProtocolReaderFactory protoReaderFactory, ProtocolWriterFactory protoWriterFactory, ContactId contactId, TransportIndex transportIndex, StreamTransportConnection connection) { - super(connReaderFactory, connWriterFactory, db, protoReaderFactory, - protoWriterFactory, contactId, connection); + super(executor, connReaderFactory, connWriterFactory, db, + protoReaderFactory, protoWriterFactory, contactId, connection); this.transportIndex = transportIndex; } diff --git a/components/net/sf/briar/transport/stream/StreamConnection.java b/components/net/sf/briar/transport/stream/StreamConnection.java index b43ada49f1..071868574c 100644 --- a/components/net/sf/briar/transport/stream/StreamConnection.java +++ b/components/net/sf/briar/transport/stream/StreamConnection.java @@ -3,12 +3,14 @@ package net.sf.briar.transport.stream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.security.GeneralSecurityException; import java.util.ArrayList; import java.util.BitSet; import java.util.Collection; import java.util.Collections; import java.util.LinkedList; import java.util.List; +import java.util.concurrent.Executor; import java.util.logging.Level; import java.util.logging.Logger; @@ -24,7 +26,6 @@ import net.sf.briar.api.db.event.LocalTransportsUpdatedEvent; import net.sf.briar.api.db.event.MessagesAddedEvent; import net.sf.briar.api.db.event.SubscriptionsUpdatedEvent; import net.sf.briar.api.protocol.Ack; -import net.sf.briar.api.protocol.Batch; import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.ProtocolReader; @@ -32,6 +33,7 @@ import net.sf.briar.api.protocol.ProtocolReaderFactory; import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.TransportUpdate; +import net.sf.briar.api.protocol.UnverifiedBatch; import net.sf.briar.api.protocol.writers.AckWriter; import net.sf.briar.api.protocol.writers.BatchWriter; import net.sf.briar.api.protocol.writers.OfferWriter; @@ -52,6 +54,7 @@ abstract class StreamConnection implements DatabaseListener { private static final Logger LOG = Logger.getLogger(StreamConnection.class.getName()); + protected final Executor executor; protected final ConnectionReaderFactory connReaderFactory; protected final ConnectionWriterFactory connWriterFactory; protected final DatabaseComponent db; @@ -65,11 +68,13 @@ abstract class StreamConnection implements DatabaseListener { private LinkedList<MessageId> requested = null; // Locking: this private Offer incomingOffer = null; // Locking: this - StreamConnection(ConnectionReaderFactory connReaderFactory, + StreamConnection(Executor executor, + ConnectionReaderFactory connReaderFactory, ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ProtocolReaderFactory protoReaderFactory, ProtocolWriterFactory protoWriterFactory, ContactId contactId, StreamTransportConnection connection) { + this.executor = executor; this.connReaderFactory = connReaderFactory; this.connWriterFactory = connWriterFactory; this.db = db; @@ -119,11 +124,34 @@ abstract class StreamConnection implements DatabaseListener { ProtocolReader proto = protoReaderFactory.createProtocolReader(in); while(!proto.eof()) { if(proto.hasAck()) { - Ack a = proto.readAck(); - db.receiveAck(contactId, a); + final Ack a = proto.readAck(); + // Store the ack on another thread + executor.execute(new Runnable() { + public void run() { + try { + db.receiveAck(contactId, a); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) + LOG.warning(e.getMessage()); + } + } + }); } else if(proto.hasBatch()) { - Batch b = proto.readBatch(); - db.receiveBatch(contactId, b); + final UnverifiedBatch b = proto.readBatch(); + // Verify and store the batch on another thread + executor.execute(new Runnable() { + public void run() { + try { + db.receiveBatch(contactId, b.verify()); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) + LOG.warning(e.getMessage()); + } catch(GeneralSecurityException e) { + if(LOG.isLoggable(Level.WARNING)) + LOG.warning(e.getMessage()); + } + } + }); } else if(proto.hasOffer()) { Offer o = proto.readOffer(); // Store the incoming offer and notify the writer @@ -151,8 +179,19 @@ abstract class StreamConnection implements DatabaseListener { if(b.get(i++)) req.add(m); else seen.add(m); } - // Mark the unrequested messages as seen - db.setSeen(contactId, Collections.unmodifiableList(seen)); + // Mark the unrequested messages as seen on another thread + final List<MessageId> l = + Collections.unmodifiableList(seen); + executor.execute(new Runnable() { + public void run() { + try { + db.setSeen(contactId, l); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) + LOG.warning(e.getMessage()); + } + } + }); // Store the requested message IDs and notify the writer synchronized(this) { if(requested != null) @@ -162,11 +201,31 @@ abstract class StreamConnection implements DatabaseListener { notifyAll(); } } else if(proto.hasSubscriptionUpdate()) { - SubscriptionUpdate s = proto.readSubscriptionUpdate(); - db.receiveSubscriptionUpdate(contactId, s); + final SubscriptionUpdate s = proto.readSubscriptionUpdate(); + // Store the update on another thread + executor.execute(new Runnable() { + public void run() { + try { + db.receiveSubscriptionUpdate(contactId, s); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) + LOG.warning(e.getMessage()); + } + } + }); } else if(proto.hasTransportUpdate()) { - TransportUpdate t = proto.readTransportUpdate(); - db.receiveTransportUpdate(contactId, t); + final TransportUpdate t = proto.readTransportUpdate(); + // Store the update on another thread + executor.execute(new Runnable() { + public void run() { + try { + db.receiveTransportUpdate(contactId, t); + } catch(DbException e) { + if(LOG.isLoggable(Level.WARNING)) + LOG.warning(e.getMessage()); + } + } + }); } else { throw new FormatException(); } diff --git a/components/net/sf/briar/transport/stream/StreamConnectionFactoryImpl.java b/components/net/sf/briar/transport/stream/StreamConnectionFactoryImpl.java index b5651a0911..659d293aee 100644 --- a/components/net/sf/briar/transport/stream/StreamConnectionFactoryImpl.java +++ b/components/net/sf/briar/transport/stream/StreamConnectionFactoryImpl.java @@ -1,5 +1,7 @@ package net.sf.briar.transport.stream; +import java.util.concurrent.Executor; + import net.sf.briar.api.ContactId; import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.protocol.ProtocolReaderFactory; @@ -15,6 +17,7 @@ import com.google.inject.Inject; class StreamConnectionFactoryImpl implements StreamConnectionFactory { + private final Executor executor; private final ConnectionReaderFactory connReaderFactory; private final ConnectionWriterFactory connWriterFactory; private final DatabaseComponent db; @@ -22,10 +25,12 @@ class StreamConnectionFactoryImpl implements StreamConnectionFactory { private final ProtocolWriterFactory protoWriterFactory; @Inject - StreamConnectionFactoryImpl(ConnectionReaderFactory connReaderFactory, + StreamConnectionFactoryImpl(Executor executor, + ConnectionReaderFactory connReaderFactory, ConnectionWriterFactory connWriterFactory, DatabaseComponent db, ProtocolReaderFactory protoReaderFactory, ProtocolWriterFactory protoWriterFactory) { + this.executor = executor; this.connReaderFactory = connReaderFactory; this.connWriterFactory = connWriterFactory; this.db = db; @@ -35,7 +40,7 @@ class StreamConnectionFactoryImpl implements StreamConnectionFactory { public void createIncomingConnection(ConnectionContext ctx, StreamTransportConnection s, byte[] tag) { - final StreamConnection conn = new IncomingStreamConnection( + final StreamConnection conn = new IncomingStreamConnection(executor, connReaderFactory, connWriterFactory, db, protoReaderFactory, protoWriterFactory, ctx, s, tag); Runnable write = new Runnable() { @@ -54,7 +59,7 @@ class StreamConnectionFactoryImpl implements StreamConnectionFactory { public void createOutgoingConnection(ContactId c, TransportIndex i, StreamTransportConnection s) { - final StreamConnection conn = new OutgoingStreamConnection( + final StreamConnection conn = new OutgoingStreamConnection(executor, connReaderFactory, connWriterFactory, db, protoReaderFactory, protoWriterFactory, c, i, s); Runnable write = new Runnable() { diff --git a/test/net/sf/briar/ProtocolIntegrationTest.java b/test/net/sf/briar/ProtocolIntegrationTest.java index c83a081d6a..7d5434a625 100644 --- a/test/net/sf/briar/ProtocolIntegrationTest.java +++ b/test/net/sf/briar/ProtocolIntegrationTest.java @@ -208,9 +208,9 @@ public class ProtocolIntegrationTest extends TestCase { Ack a = protocolReader.readAck(); assertEquals(Collections.singletonList(ack), a.getBatchIds()); - // Read the batch + // Read and verify the batch assertTrue(protocolReader.hasBatch()); - Batch b = protocolReader.readBatch(); + Batch b = protocolReader.readBatch().verify(); Collection<Message> messages = b.getMessages(); assertEquals(4, messages.size()); Iterator<Message> it = messages.iterator(); diff --git a/test/net/sf/briar/protocol/BatchReaderTest.java b/test/net/sf/briar/protocol/BatchReaderTest.java index f2971e511b..8f8898fd5d 100644 --- a/test/net/sf/briar/protocol/BatchReaderTest.java +++ b/test/net/sf/briar/protocol/BatchReaderTest.java @@ -9,11 +9,10 @@ import junit.framework.TestCase; import net.sf.briar.api.FormatException; import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.crypto.MessageDigest; -import net.sf.briar.api.protocol.Batch; import net.sf.briar.api.protocol.BatchId; -import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Types; +import net.sf.briar.api.protocol.UnverifiedBatch; import net.sf.briar.api.serial.ObjectReader; import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.ReaderFactory; @@ -35,7 +34,8 @@ public class BatchReaderTest extends TestCase { private final WriterFactory writerFactory; private final CryptoComponent crypto; private final Mockery context; - private final Message message; + private final UnverifiedMessage message; + private final ObjectReader<UnverifiedMessage> messageReader; public BatchReaderTest() throws Exception { super(); @@ -45,13 +45,14 @@ public class BatchReaderTest extends TestCase { writerFactory = i.getInstance(WriterFactory.class); crypto = i.getInstance(CryptoComponent.class); context = new Mockery(); - message = context.mock(Message.class); + message = context.mock(UnverifiedMessage.class); + messageReader = new TestMessageReader(); } @Test public void testFormatExceptionIfBatchIsTooLarge() throws Exception { - ObjectReader<Message> messageReader = new TestMessageReader(); - BatchFactory batchFactory = context.mock(BatchFactory.class); + UnverifiedBatchFactory batchFactory = + context.mock(UnverifiedBatchFactory.class); BatchReader batchReader = new BatchReader(crypto, messageReader, batchFactory); @@ -61,7 +62,7 @@ public class BatchReaderTest extends TestCase { reader.addObjectReader(Types.BATCH, batchReader); try { - reader.readStruct(Types.BATCH, Batch.class); + reader.readStruct(Types.BATCH, UnverifiedBatch.class); fail(); } catch(FormatException expected) {} context.assertIsSatisfied(); @@ -69,13 +70,13 @@ public class BatchReaderTest extends TestCase { @Test public void testNoFormatExceptionIfBatchIsMaximumSize() throws Exception { - ObjectReader<Message> messageReader = new TestMessageReader(); - final BatchFactory batchFactory = context.mock(BatchFactory.class); + final UnverifiedBatchFactory batchFactory = + context.mock(UnverifiedBatchFactory.class); BatchReader batchReader = new BatchReader(crypto, messageReader, batchFactory); - final Batch batch = context.mock(Batch.class); + final UnverifiedBatch batch = context.mock(UnverifiedBatch.class); context.checking(new Expectations() {{ - oneOf(batchFactory).createBatch(with(any(BatchId.class)), + oneOf(batchFactory).createUnverifiedBatch(with(any(BatchId.class)), with(Collections.singletonList(message))); will(returnValue(batch)); }}); @@ -85,7 +86,8 @@ public class BatchReaderTest extends TestCase { Reader reader = readerFactory.createReader(in); reader.addObjectReader(Types.BATCH, batchReader); - assertEquals(batch, reader.readStruct(Types.BATCH, Batch.class)); + assertEquals(batch, reader.readStruct(Types.BATCH, + UnverifiedBatch.class)); context.assertIsSatisfied(); } @@ -98,14 +100,14 @@ public class BatchReaderTest extends TestCase { messageDigest.update(b); final BatchId id = new BatchId(messageDigest.digest()); - ObjectReader<Message> messageReader = new TestMessageReader(); - final BatchFactory batchFactory = context.mock(BatchFactory.class); + final UnverifiedBatchFactory batchFactory = + context.mock(UnverifiedBatchFactory.class); BatchReader batchReader = new BatchReader(crypto, messageReader, batchFactory); - final Batch batch = context.mock(Batch.class); + final UnverifiedBatch batch = context.mock(UnverifiedBatch.class); context.checking(new Expectations() {{ // Check that the batch ID matches the expected ID - oneOf(batchFactory).createBatch(with(id), + oneOf(batchFactory).createUnverifiedBatch(with(id), with(Collections.singletonList(message))); will(returnValue(batch)); }}); @@ -114,20 +116,21 @@ public class BatchReaderTest extends TestCase { Reader reader = readerFactory.createReader(in); reader.addObjectReader(Types.BATCH, batchReader); - assertEquals(batch, reader.readStruct(Types.BATCH, Batch.class)); + assertEquals(batch, reader.readStruct(Types.BATCH, + UnverifiedBatch.class)); context.assertIsSatisfied(); } @Test public void testEmptyBatch() throws Exception { - ObjectReader<Message> messageReader = new TestMessageReader(); - final BatchFactory batchFactory = context.mock(BatchFactory.class); + final UnverifiedBatchFactory batchFactory = + context.mock(UnverifiedBatchFactory.class); BatchReader batchReader = new BatchReader(crypto, messageReader, batchFactory); - final Batch batch = context.mock(Batch.class); + final UnverifiedBatch batch = context.mock(UnverifiedBatch.class); context.checking(new Expectations() {{ - oneOf(batchFactory).createBatch(with(any(BatchId.class)), - with(Collections.<Message>emptyList())); + oneOf(batchFactory).createUnverifiedBatch(with(any(BatchId.class)), + with(Collections.<UnverifiedMessage>emptyList())); will(returnValue(batch)); }}); @@ -136,7 +139,8 @@ public class BatchReaderTest extends TestCase { Reader reader = readerFactory.createReader(in); reader.addObjectReader(Types.BATCH, batchReader); - assertEquals(batch, reader.readStruct(Types.BATCH, Batch.class)); + assertEquals(batch, reader.readStruct(Types.BATCH, + UnverifiedBatch.class)); context.assertIsSatisfied(); } @@ -163,9 +167,9 @@ public class BatchReaderTest extends TestCase { return out.toByteArray(); } - private class TestMessageReader implements ObjectReader<Message> { + private class TestMessageReader implements ObjectReader<UnverifiedMessage> { - public Message readObject(Reader r) throws IOException { + public UnverifiedMessage readObject(Reader r) throws IOException { r.readStructId(Types.MESSAGE); r.readBytes(); return message; diff --git a/test/net/sf/briar/protocol/ProtocolReadWriteTest.java b/test/net/sf/briar/protocol/ProtocolReadWriteTest.java index 96efbbba58..121b25e4b6 100644 --- a/test/net/sf/briar/protocol/ProtocolReadWriteTest.java +++ b/test/net/sf/briar/protocol/ProtocolReadWriteTest.java @@ -114,7 +114,7 @@ public class ProtocolReadWriteTest extends TestCase { Ack ack = reader.readAck(); assertEquals(Collections.singletonList(batchId), ack.getBatchIds()); - Batch batch = reader.readBatch(); + Batch batch = reader.readBatch().verify(); assertEquals(Collections.singletonList(message), batch.getMessages()); Offer offer = reader.readOffer(); diff --git a/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java b/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java index 1e7ff71255..a25a9aeb5f 100644 --- a/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java +++ b/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java @@ -41,6 +41,7 @@ import net.sf.briar.api.transport.ConnectionWriterFactory; import net.sf.briar.crypto.CryptoModule; import net.sf.briar.db.DatabaseModule; import net.sf.briar.lifecycle.LifecycleModule; +import net.sf.briar.plugins.ImmediateExecutor; import net.sf.briar.protocol.ProtocolModule; import net.sf.briar.protocol.writers.ProtocolWritersModule; import net.sf.briar.serial.SerialModule; @@ -187,7 +188,8 @@ public class BatchConnectionReadWriteTest extends TestCase { bob.getInstance(ProtocolReaderFactory.class); BatchTransportReader reader = new TestBatchTransportReader(in); IncomingBatchConnection batchIn = new IncomingBatchConnection( - connFactory, db, protoFactory, ctx, reader, tag); + new ImmediateExecutor(), connFactory, db, protoFactory, ctx, + reader, tag); // No messages should have been added yet assertFalse(listener.messagesAdded); // Read whatever needs to be read -- GitLab