diff --git a/briar-api/src/org/briarproject/api/clients/MessageQueueManager.java b/briar-api/src/org/briarproject/api/clients/MessageQueueManager.java new file mode 100644 index 0000000000000000000000000000000000000000..52bb298ac466347bd54e600e8c2271244273a41d --- /dev/null +++ b/briar-api/src/org/briarproject/api/clients/MessageQueueManager.java @@ -0,0 +1,52 @@ +package org.briarproject.api.clients; + +import org.briarproject.api.db.DbException; +import org.briarproject.api.db.Metadata; +import org.briarproject.api.db.Transaction; +import org.briarproject.api.sync.ClientId; +import org.briarproject.api.sync.Group; + +public interface MessageQueueManager { + + /** + * The key used for storing the queue's state in the group metadata. + */ + String QUEUE_STATE_KEY = "queueState"; + + /** + * Sends a message using the given queue. + */ + QueueMessage sendMessage(Transaction txn, Group queue, long timestamp, + byte[] body, Metadata meta) throws DbException; + + /** + * Sets the message validator for the given client. + */ + void registerMessageValidator(ClientId c, QueueMessageValidator v); + + /** + * Sets the incoming message hook for the given client. The hook will be + * called once for each incoming message that passes validation. Messages + * are passed to the hook in order. + */ + void registerIncomingMessageHook(ClientId c, IncomingQueueMessageHook hook); + + interface QueueMessageValidator { + + /** + * Validates the given message and returns its metadata if the message + * is valid, or null if the message is invalid. + */ + Metadata validateMessage(QueueMessage q, Group g); + } + + interface IncomingQueueMessageHook { + + /** + * Called once for each incoming message that passes validation. + * Messages are passed to the hook in order. + */ + void incomingMessage(Transaction txn, QueueMessage q, Metadata meta) + throws DbException; + } +} diff --git a/briar-api/src/org/briarproject/api/clients/QueueMessage.java b/briar-api/src/org/briarproject/api/clients/QueueMessage.java new file mode 100644 index 0000000000000000000000000000000000000000..0979c3d19493d663292cbda98020b53ee3cb3ebe --- /dev/null +++ b/briar-api/src/org/briarproject/api/clients/QueueMessage.java @@ -0,0 +1,28 @@ +package org.briarproject.api.clients; + +import org.briarproject.api.sync.GroupId; +import org.briarproject.api.sync.Message; +import org.briarproject.api.sync.MessageId; + +import static org.briarproject.api.sync.SyncConstants.MAX_MESSAGE_BODY_LENGTH; +import static org.briarproject.api.sync.SyncConstants.MESSAGE_HEADER_LENGTH; + +public class QueueMessage extends Message { + + public static final int QUEUE_MESSAGE_HEADER_LENGTH = + MESSAGE_HEADER_LENGTH + 8; + public static final int MAX_QUEUE_MESSAGE_BODY_LENGTH = + MAX_MESSAGE_BODY_LENGTH - 8; + + private final long queuePosition; + + public QueueMessage(MessageId id, GroupId groupId, long timestamp, + long queuePosition, byte[] raw) { + super(id, groupId, timestamp, raw); + this.queuePosition = queuePosition; + } + + public long getQueuePosition() { + return queuePosition; + } +} diff --git a/briar-api/src/org/briarproject/api/clients/QueueMessageFactory.java b/briar-api/src/org/briarproject/api/clients/QueueMessageFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..cea02f230f7c9b3a42657a4eb80c894e086d597d --- /dev/null +++ b/briar-api/src/org/briarproject/api/clients/QueueMessageFactory.java @@ -0,0 +1,12 @@ +package org.briarproject.api.clients; + +import org.briarproject.api.sync.GroupId; +import org.briarproject.api.sync.MessageId; + +public interface QueueMessageFactory { + + QueueMessage createMessage(GroupId groupId, long timestamp, + long queuePosition, byte[] body); + + QueueMessage createMessage(MessageId id, byte[] raw); +} diff --git a/briar-api/src/org/briarproject/api/sync/MessageValidator.java b/briar-api/src/org/briarproject/api/sync/MessageValidator.java deleted file mode 100644 index 58ee5dc2b96b980d0dea03a85a50951b0f23e4c5..0000000000000000000000000000000000000000 --- a/briar-api/src/org/briarproject/api/sync/MessageValidator.java +++ /dev/null @@ -1,12 +0,0 @@ -package org.briarproject.api.sync; - -import org.briarproject.api.db.Metadata; - -public interface MessageValidator { - - /** - * Validates the given message and returns its metadata if the message - * is valid, or null if the message is invalid. - */ - Metadata validateMessage(Message m, Group g); -} diff --git a/briar-api/src/org/briarproject/api/sync/ValidationManager.java b/briar-api/src/org/briarproject/api/sync/ValidationManager.java index 690bc54e23f9fa110a7721b8e5f6284fad39752f..abea827e2b2f7ba2eccd0cab10844598000f2b33 100644 --- a/briar-api/src/org/briarproject/api/sync/ValidationManager.java +++ b/briar-api/src/org/briarproject/api/sync/ValidationManager.java @@ -30,14 +30,32 @@ public interface ValidationManager { } } - /** Sets the message validator for the given client. */ + /** + * Sets the message validator for the given client. + */ void registerMessageValidator(ClientId c, MessageValidator v); - /** Registers a hook to be called whenever a message is validated. */ - void registerValidationHook(ValidationHook hook); + /** + * Sets the incoming message hook for the given client. The hook will be + * called once for each incoming message that passes validation. + */ + void registerIncomingMessageHook(ClientId c, IncomingMessageHook hook); - interface ValidationHook { - void validatingMessage(Transaction txn, Message m, ClientId c, - Metadata meta) throws DbException; + interface MessageValidator { + + /** + * Validates the given message and returns its metadata if the message + * is valid, or null if the message is invalid. + */ + Metadata validateMessage(Message m, Group g); + } + + interface IncomingMessageHook { + + /** + * Called once for each incoming message that passes validation. + */ + void incomingMessage(Transaction txn, Message m, Metadata meta) + throws DbException; } } diff --git a/briar-core/src/org/briarproject/clients/BdfIncomingMessageHook.java b/briar-core/src/org/briarproject/clients/BdfIncomingMessageHook.java new file mode 100644 index 0000000000000000000000000000000000000000..5380652c6d84dccb25ef4cab8c9a2ea837d5144d --- /dev/null +++ b/briar-core/src/org/briarproject/clients/BdfIncomingMessageHook.java @@ -0,0 +1,60 @@ +package org.briarproject.clients; + +import org.briarproject.api.FormatException; +import org.briarproject.api.clients.ClientHelper; +import org.briarproject.api.clients.MessageQueueManager.IncomingQueueMessageHook; +import org.briarproject.api.clients.QueueMessage; +import org.briarproject.api.data.BdfDictionary; +import org.briarproject.api.data.BdfList; +import org.briarproject.api.data.MetadataParser; +import org.briarproject.api.db.DbException; +import org.briarproject.api.db.Metadata; +import org.briarproject.api.db.Transaction; +import org.briarproject.api.sync.Message; +import org.briarproject.api.sync.ValidationManager.IncomingMessageHook; +import org.briarproject.api.system.Clock; + +import static org.briarproject.api.clients.QueueMessage.QUEUE_MESSAGE_HEADER_LENGTH; +import static org.briarproject.api.sync.SyncConstants.MESSAGE_HEADER_LENGTH; + +public abstract class BdfIncomingMessageHook implements IncomingMessageHook, + IncomingQueueMessageHook { + + protected final ClientHelper clientHelper; + protected final MetadataParser metadataParser; + + protected BdfIncomingMessageHook(ClientHelper clientHelper, + MetadataParser metadataParser, Clock clock) { + this.clientHelper = clientHelper; + this.metadataParser = metadataParser; + } + + protected abstract void incomingMessage(Transaction txn, Message m, + BdfList body, BdfDictionary meta) throws DbException, + FormatException; + + @Override + public void incomingMessage(Transaction txn, Message m, Metadata meta) + throws DbException { + incomingMessage(txn, m, meta, MESSAGE_HEADER_LENGTH); + } + + @Override + public void incomingMessage(Transaction txn, QueueMessage q, Metadata meta) + throws DbException { + incomingMessage(txn, q, meta, QUEUE_MESSAGE_HEADER_LENGTH); + } + + private void incomingMessage(Transaction txn, Message m, Metadata meta, + int headerLength) throws DbException { + try { + byte[] raw = m.getRaw(); + BdfList body = clientHelper.toList(raw, headerLength, + raw.length - headerLength); + BdfDictionary metaDictionary = metadataParser.parse(meta); + incomingMessage(txn, m, body, metaDictionary); + } catch (FormatException e) { + throw new DbException(e); + } + } +} diff --git a/briar-core/src/org/briarproject/clients/BdfMessageValidator.java b/briar-core/src/org/briarproject/clients/BdfMessageValidator.java index fd170aa7c89d70e5e754a85924df59c593e5debe..19bc55db000f1dd841167c6f3eb4d0fd03a514e5 100644 --- a/briar-core/src/org/briarproject/clients/BdfMessageValidator.java +++ b/briar-core/src/org/briarproject/clients/BdfMessageValidator.java @@ -2,22 +2,26 @@ package org.briarproject.clients; import org.briarproject.api.FormatException; import org.briarproject.api.clients.ClientHelper; +import org.briarproject.api.clients.MessageQueueManager.QueueMessageValidator; +import org.briarproject.api.clients.QueueMessage; import org.briarproject.api.data.BdfDictionary; import org.briarproject.api.data.BdfList; import org.briarproject.api.data.MetadataEncoder; import org.briarproject.api.db.Metadata; import org.briarproject.api.sync.Group; import org.briarproject.api.sync.Message; -import org.briarproject.api.sync.MessageValidator; +import org.briarproject.api.sync.ValidationManager.MessageValidator; import org.briarproject.api.system.Clock; import org.briarproject.util.StringUtils; import java.util.logging.Logger; +import static org.briarproject.api.clients.QueueMessage.QUEUE_MESSAGE_HEADER_LENGTH; import static org.briarproject.api.sync.SyncConstants.MESSAGE_HEADER_LENGTH; import static org.briarproject.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE; -public abstract class BdfMessageValidator implements MessageValidator { +public abstract class BdfMessageValidator implements MessageValidator, + QueueMessageValidator { protected static final Logger LOG = Logger.getLogger(BdfMessageValidator.class.getName()); @@ -33,11 +37,20 @@ public abstract class BdfMessageValidator implements MessageValidator { this.clock = clock; } - protected abstract BdfDictionary validateMessage(BdfList message, Group g, - long timestamp) throws FormatException; + protected abstract BdfDictionary validateMessage(Message m, Group g, + BdfList body) throws FormatException; @Override public Metadata validateMessage(Message m, Group g) { + return validateMessage(m, g, MESSAGE_HEADER_LENGTH); + } + + @Override + public Metadata validateMessage(QueueMessage q, Group g) { + return validateMessage(q, g, QUEUE_MESSAGE_HEADER_LENGTH); + } + + private Metadata validateMessage(Message m, Group g, int headerLength) { // Reject the message if it's too far in the future long now = clock.currentTimeMillis(); if (m.getTimestamp() - now > MAX_CLOCK_DIFFERENCE) { @@ -45,10 +58,14 @@ public abstract class BdfMessageValidator implements MessageValidator { return null; } byte[] raw = m.getRaw(); + if (raw.length <= headerLength) { + LOG.info("Message is too short"); + return null; + } try { - BdfList message = clientHelper.toList(raw, MESSAGE_HEADER_LENGTH, - raw.length - MESSAGE_HEADER_LENGTH); - BdfDictionary meta = validateMessage(message, g, m.getTimestamp()); + BdfList body = clientHelper.toList(raw, headerLength, + raw.length - headerLength); + BdfDictionary meta = validateMessage(m, g, body); if (meta == null) { LOG.info("Invalid message"); return null; @@ -87,7 +104,7 @@ public abstract class BdfMessageValidator implements MessageValidator { } protected void checkSize(BdfList list, int minSize, int maxSize) - throws FormatException { + throws FormatException { if (list != null) { if (list.size() < minSize) throw new FormatException(); if (list.size() > maxSize) throw new FormatException(); diff --git a/briar-core/src/org/briarproject/clients/ClientsModule.java b/briar-core/src/org/briarproject/clients/ClientsModule.java index 4f3a1d4a9fd616b575af179bbc2be331f89c7af6..b4c1a306cee9f43240883584eae7144ed906f365 100644 --- a/briar-core/src/org/briarproject/clients/ClientsModule.java +++ b/briar-core/src/org/briarproject/clients/ClientsModule.java @@ -3,13 +3,17 @@ package org.briarproject.clients; import com.google.inject.AbstractModule; import org.briarproject.api.clients.ClientHelper; +import org.briarproject.api.clients.MessageQueueManager; import org.briarproject.api.clients.PrivateGroupFactory; +import org.briarproject.api.clients.QueueMessageFactory; public class ClientsModule extends AbstractModule { @Override protected void configure() { bind(ClientHelper.class).to(ClientHelperImpl.class); + bind(MessageQueueManager.class).to(MessageQueueManagerImpl.class); bind(PrivateGroupFactory.class).to(PrivateGroupFactoryImpl.class); + bind(QueueMessageFactory.class).to(QueueMessageFactoryImpl.class); } } diff --git a/briar-core/src/org/briarproject/clients/MessageQueueManagerImpl.java b/briar-core/src/org/briarproject/clients/MessageQueueManagerImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..5dadef907e02175ceeb6ed3a3a33bf3811a73d3f --- /dev/null +++ b/briar-core/src/org/briarproject/clients/MessageQueueManagerImpl.java @@ -0,0 +1,221 @@ +package org.briarproject.clients; + +import org.briarproject.api.FormatException; +import org.briarproject.api.clients.ClientHelper; +import org.briarproject.api.clients.MessageQueueManager; +import org.briarproject.api.clients.QueueMessage; +import org.briarproject.api.clients.QueueMessageFactory; +import org.briarproject.api.data.BdfDictionary; +import org.briarproject.api.data.BdfList; +import org.briarproject.api.db.DatabaseComponent; +import org.briarproject.api.db.DbException; +import org.briarproject.api.db.Metadata; +import org.briarproject.api.db.Transaction; +import org.briarproject.api.sync.ClientId; +import org.briarproject.api.sync.Group; +import org.briarproject.api.sync.GroupId; +import org.briarproject.api.sync.Message; +import org.briarproject.api.sync.MessageId; +import org.briarproject.api.sync.ValidationManager; +import org.briarproject.api.sync.ValidationManager.IncomingMessageHook; +import org.briarproject.util.ByteUtils; + +import java.util.Iterator; +import java.util.Map.Entry; +import java.util.TreeMap; +import java.util.logging.Logger; + +import javax.inject.Inject; + +import static java.util.logging.Level.INFO; +import static org.briarproject.api.clients.QueueMessage.QUEUE_MESSAGE_HEADER_LENGTH; +import static org.briarproject.api.sync.SyncConstants.MESSAGE_HEADER_LENGTH; + +class MessageQueueManagerImpl implements MessageQueueManager { + + private static final String OUTGOING_POSITION_KEY = "nextOut"; + private static final String INCOMING_POSITION_KEY = "nextIn"; + private static final String PENDING_MESSAGES_KEY = "pending"; + + private static final Logger LOG = + Logger.getLogger(MessageQueueManagerImpl.class.getName()); + + private final DatabaseComponent db; + private final ClientHelper clientHelper; + private final QueueMessageFactory queueMessageFactory; + private final ValidationManager validationManager; + + @Inject + MessageQueueManagerImpl(DatabaseComponent db, ClientHelper clientHelper, + QueueMessageFactory queueMessageFactory, + ValidationManager validationManager) { + this.db = db; + this.clientHelper = clientHelper; + this.queueMessageFactory = queueMessageFactory; + this.validationManager = validationManager; + } + + @Override + public QueueMessage sendMessage(Transaction txn, Group queue, + long timestamp, byte[] body, Metadata meta) throws DbException { + QueueState queueState = loadQueueState(txn, queue.getId()); + long queuePosition = queueState.outgoingPosition; + queueState.outgoingPosition++; + saveQueueState(txn, queue.getId(), queueState); + QueueMessage q = queueMessageFactory.createMessage(queue.getId(), + timestamp, queuePosition, body); + db.addLocalMessage(txn, q, queue.getClientId(), meta, true); + return q; + } + + @Override + public void registerMessageValidator(ClientId c, QueueMessageValidator v) { + validationManager.registerMessageValidator(c, + new DelegatingMessageValidator(v)); + } + + @Override + public void registerIncomingMessageHook(ClientId c, + IncomingQueueMessageHook hook) { + validationManager.registerIncomingMessageHook(c, + new DelegatingIncomingMessageHook(hook)); + } + + private QueueState loadQueueState(Transaction txn, GroupId g) + throws DbException { + try { + TreeMap<Long, MessageId> pending = new TreeMap<Long, MessageId>(); + Metadata groupMeta = db.getGroupMetadata(txn, g); + byte[] raw = groupMeta.get(QUEUE_STATE_KEY); + if (raw == null) return new QueueState(0, 0, pending); + BdfDictionary d = clientHelper.toDictionary(raw, 0, raw.length); + long outgoingPosition = d.getLong(OUTGOING_POSITION_KEY); + long incomingPosition = d.getLong(INCOMING_POSITION_KEY); + BdfList pendingList = d.getList(PENDING_MESSAGES_KEY); + for (int i = 0; i < pendingList.size(); i++) { + BdfList item = pendingList.getList(i); + if (item.size() != 2) throw new FormatException(); + pending.put(item.getLong(0), new MessageId(item.getRaw(1))); + } + return new QueueState(outgoingPosition, incomingPosition, pending); + } catch (FormatException e) { + throw new DbException(e); + } + } + + private void saveQueueState(Transaction txn, GroupId g, + QueueState queueState) throws DbException { + try { + BdfDictionary d = new BdfDictionary(); + d.put(OUTGOING_POSITION_KEY, queueState.outgoingPosition); + d.put(INCOMING_POSITION_KEY, queueState.incomingPosition); + BdfList pendingList = new BdfList(); + for (Entry<Long, MessageId> e : queueState.pending.entrySet()) + pendingList.add(BdfList.of(e.getKey(), e.getValue())); + d.put(PENDING_MESSAGES_KEY, pendingList); + Metadata groupMeta = new Metadata(); + groupMeta.put(QUEUE_STATE_KEY, clientHelper.toByteArray(d)); + db.mergeGroupMetadata(txn, g, groupMeta); + } catch (FormatException e) { + throw new RuntimeException(e); + } + } + + private static class QueueState { + + private long outgoingPosition, incomingPosition; + private final TreeMap<Long, MessageId> pending; + + QueueState(long outgoingPosition, long incomingPosition, + TreeMap<Long, MessageId> pending) { + this.outgoingPosition = outgoingPosition; + this.incomingPosition = incomingPosition; + this.pending = pending; + } + + MessageId popIncomingMessageId() { + Iterator<Entry<Long, MessageId>> it = pending.entrySet().iterator(); + if (!it.hasNext()) return null; + Entry<Long, MessageId> e = it.next(); + if (!e.getKey().equals(incomingPosition)) return null; + it.remove(); + incomingPosition++; + return e.getValue(); + } + } + + private static class DelegatingMessageValidator + implements ValidationManager.MessageValidator { + + private final QueueMessageValidator delegate; + + DelegatingMessageValidator(QueueMessageValidator delegate) { + this.delegate = delegate; + } + + @Override + public Metadata validateMessage(Message m, Group g) { + byte[] raw = m.getRaw(); + if (raw.length < QUEUE_MESSAGE_HEADER_LENGTH) return null; + long queuePosition = ByteUtils.readUint64(raw, + MESSAGE_HEADER_LENGTH); + if (queuePosition < 0) return null; + QueueMessage q = new QueueMessage(m.getId(), m.getGroupId(), + m.getTimestamp(), queuePosition, raw); + return delegate.validateMessage(q, g); + } + } + + private class DelegatingIncomingMessageHook implements IncomingMessageHook { + + private final IncomingQueueMessageHook delegate; + + DelegatingIncomingMessageHook(IncomingQueueMessageHook delegate) { + this.delegate = delegate; + } + + @Override + public void incomingMessage(Transaction txn, Message m, Metadata meta) + throws DbException { + long queuePosition = ByteUtils.readUint64(m.getRaw(), + MESSAGE_HEADER_LENGTH); + QueueState queueState = loadQueueState(txn, m.getGroupId()); + if (LOG.isLoggable(INFO)) { + LOG.info("Received message with position " + + queuePosition + ", expecting " + + queueState.incomingPosition); + } + if (queuePosition < queueState.incomingPosition) { + // A message with this queue position has already been seen + LOG.warning("Deleting message with duplicate position"); + db.deleteMessage(txn, m.getId()); + db.deleteMessageMetadata(txn, m.getId()); + } else if (queuePosition > queueState.incomingPosition) { + // The message is out of order, add it to the pending list + LOG.info("Message is out of order, adding to pending list"); + queueState.pending.put(queuePosition, m.getId()); + saveQueueState(txn, m.getGroupId(), queueState); + } else { + // The message is in order, pass it to the delegate + LOG.info("Message is in order, delivering"); + QueueMessage q = new QueueMessage(m.getId(), m.getGroupId(), + m.getTimestamp(), queuePosition, m.getRaw()); + delegate.incomingMessage(txn, q, meta); + queueState.incomingPosition++; + // Pass any consecutive messages to the delegate + MessageId id; + while ((id = queueState.popIncomingMessageId()) != null) { + byte[] raw = db.getRawMessage(txn, id); + meta = db.getMessageMetadata(txn, id); + q = queueMessageFactory.createMessage(id, raw); + if (LOG.isLoggable(INFO)) { + LOG.info("Delivering pending message with position " + + q.getQueuePosition()); + } + delegate.incomingMessage(txn, q, meta); + } + saveQueueState(txn, m.getGroupId(), queueState); + } + } + } +} diff --git a/briar-core/src/org/briarproject/clients/PrivateGroupFactoryImpl.java b/briar-core/src/org/briarproject/clients/PrivateGroupFactoryImpl.java index c1842e6ba6230dca5ede95fd579c8c72d7b274de..5fd4ee98ae535ec8a3d14f1e886f6fa7c8a12c03 100644 --- a/briar-core/src/org/briarproject/clients/PrivateGroupFactoryImpl.java +++ b/briar-core/src/org/briarproject/clients/PrivateGroupFactoryImpl.java @@ -3,28 +3,26 @@ package org.briarproject.clients; import com.google.inject.Inject; import org.briarproject.api.Bytes; +import org.briarproject.api.FormatException; +import org.briarproject.api.clients.ClientHelper; import org.briarproject.api.clients.PrivateGroupFactory; import org.briarproject.api.contact.Contact; -import org.briarproject.api.data.BdfWriter; -import org.briarproject.api.data.BdfWriterFactory; +import org.briarproject.api.data.BdfList; import org.briarproject.api.identity.AuthorId; import org.briarproject.api.sync.ClientId; import org.briarproject.api.sync.Group; import org.briarproject.api.sync.GroupFactory; -import java.io.ByteArrayOutputStream; -import java.io.IOException; - class PrivateGroupFactoryImpl implements PrivateGroupFactory { private final GroupFactory groupFactory; - private final BdfWriterFactory bdfWriterFactory; + private final ClientHelper clientHelper; @Inject PrivateGroupFactoryImpl(GroupFactory groupFactory, - BdfWriterFactory bdfWriterFactory) { + ClientHelper clientHelper) { this.groupFactory = groupFactory; - this.bdfWriterFactory = bdfWriterFactory; + this.clientHelper = clientHelper; } @Override @@ -36,22 +34,12 @@ class PrivateGroupFactoryImpl implements PrivateGroupFactory { } private byte[] createGroupDescriptor(AuthorId local, AuthorId remote) { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - BdfWriter w = bdfWriterFactory.createWriter(out); try { - w.writeListStart(); - if (Bytes.COMPARATOR.compare(local, remote) < 0) { - w.writeRaw(local.getBytes()); - w.writeRaw(remote.getBytes()); - } else { - w.writeRaw(remote.getBytes()); - w.writeRaw(local.getBytes()); - } - w.writeListEnd(); - } catch (IOException e) { - // Shouldn't happen with ByteArrayOutputStream + if (Bytes.COMPARATOR.compare(local, remote) < 0) + return clientHelper.toByteArray(BdfList.of(local, remote)); + else return clientHelper.toByteArray(BdfList.of(remote, local)); + } catch (FormatException e) { throw new RuntimeException(e); } - return out.toByteArray(); } } diff --git a/briar-core/src/org/briarproject/clients/QueueMessageFactoryImpl.java b/briar-core/src/org/briarproject/clients/QueueMessageFactoryImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..d64cbb5713f548c846941397c659d856ebdef15a --- /dev/null +++ b/briar-core/src/org/briarproject/clients/QueueMessageFactoryImpl.java @@ -0,0 +1,55 @@ +package org.briarproject.clients; + +import org.briarproject.api.UniqueId; +import org.briarproject.api.clients.QueueMessage; +import org.briarproject.api.clients.QueueMessageFactory; +import org.briarproject.api.crypto.CryptoComponent; +import org.briarproject.api.sync.GroupId; +import org.briarproject.api.sync.MessageId; +import org.briarproject.util.ByteUtils; + +import javax.inject.Inject; + +import static org.briarproject.api.clients.QueueMessage.MAX_QUEUE_MESSAGE_BODY_LENGTH; +import static org.briarproject.api.clients.QueueMessage.QUEUE_MESSAGE_HEADER_LENGTH; +import static org.briarproject.api.sync.SyncConstants.MAX_MESSAGE_LENGTH; +import static org.briarproject.api.sync.SyncConstants.MESSAGE_HEADER_LENGTH; + +class QueueMessageFactoryImpl implements QueueMessageFactory { + + private final CryptoComponent crypto; + + @Inject + QueueMessageFactoryImpl(CryptoComponent crypto) { + this.crypto = crypto; + } + + @Override + public QueueMessage createMessage(GroupId groupId, long timestamp, + long queuePosition, byte[] body) { + if (body.length > MAX_QUEUE_MESSAGE_BODY_LENGTH) + throw new IllegalArgumentException(); + byte[] raw = new byte[QUEUE_MESSAGE_HEADER_LENGTH + body.length]; + System.arraycopy(groupId.getBytes(), 0, raw, 0, UniqueId.LENGTH); + ByteUtils.writeUint64(timestamp, raw, UniqueId.LENGTH); + ByteUtils.writeUint64(queuePosition, raw, MESSAGE_HEADER_LENGTH); + System.arraycopy(body, 0, raw, QUEUE_MESSAGE_HEADER_LENGTH, + body.length); + MessageId id = new MessageId(crypto.hash(MessageId.LABEL, raw)); + return new QueueMessage(id, groupId, timestamp, queuePosition, raw); + } + + @Override + public QueueMessage createMessage(MessageId id, byte[] raw) { + if (raw.length < QUEUE_MESSAGE_HEADER_LENGTH) + throw new IllegalArgumentException(); + if (raw.length > MAX_MESSAGE_LENGTH) + throw new IllegalArgumentException(); + byte[] groupId = new byte[UniqueId.LENGTH]; + System.arraycopy(raw, 0, groupId, 0, UniqueId.LENGTH); + long timestamp = ByteUtils.readUint64(raw, UniqueId.LENGTH); + long queuePosition = ByteUtils.readUint64(raw, MESSAGE_HEADER_LENGTH); + return new QueueMessage(id, new GroupId(groupId), timestamp, + queuePosition, raw); + } +} diff --git a/briar-core/src/org/briarproject/forum/ForumListValidator.java b/briar-core/src/org/briarproject/forum/ForumListValidator.java index e4b6e4edc6b3775e57ca3332c092b989f85c1735..9e6947a377f4d12b4d02bff1a6b380a1831798e6 100644 --- a/briar-core/src/org/briarproject/forum/ForumListValidator.java +++ b/briar-core/src/org/briarproject/forum/ForumListValidator.java @@ -6,6 +6,7 @@ import org.briarproject.api.data.BdfDictionary; import org.briarproject.api.data.BdfList; import org.briarproject.api.data.MetadataEncoder; import org.briarproject.api.sync.Group; +import org.briarproject.api.sync.Message; import org.briarproject.api.system.Clock; import org.briarproject.clients.BdfMessageValidator; @@ -20,15 +21,15 @@ class ForumListValidator extends BdfMessageValidator { } @Override - public BdfDictionary validateMessage(BdfList message, Group g, - long timestamp) throws FormatException { + protected BdfDictionary validateMessage(Message m, Group g, + BdfList body) throws FormatException { // Version, forum list - checkSize(message, 2); + checkSize(body, 2); // Version - long version = message.getLong(0); + long version = body.getLong(0); if (version < 0) throw new FormatException(); // Forum list - BdfList forumList = message.getList(1); + BdfList forumList = body.getList(1); for (int i = 0; i < forumList.size(); i++) { BdfList forum = forumList.getList(i); // Name, salt diff --git a/briar-core/src/org/briarproject/forum/ForumModule.java b/briar-core/src/org/briarproject/forum/ForumModule.java index 56c18546f86f56ec10e675d811f45fdee6803e30..19e7ce06f135b2c2917f5a9142c60bafe3bef460 100644 --- a/briar-core/src/org/briarproject/forum/ForumModule.java +++ b/briar-core/src/org/briarproject/forum/ForumModule.java @@ -53,7 +53,8 @@ public class ForumModule extends AbstractModule { ForumSharingManagerImpl forumSharingManager) { contactManager.registerAddContactHook(forumSharingManager); contactManager.registerRemoveContactHook(forumSharingManager); - validationManager.registerValidationHook(forumSharingManager); + validationManager.registerIncomingMessageHook( + ForumSharingManagerImpl.CLIENT_ID, forumSharingManager); return forumSharingManager; } } diff --git a/briar-core/src/org/briarproject/forum/ForumPostValidator.java b/briar-core/src/org/briarproject/forum/ForumPostValidator.java index 8bc2734502bde6e0d654bca17a1356217b19611e..aac7b6265de9f221f248afeabd014770138fc47c 100644 --- a/briar-core/src/org/briarproject/forum/ForumPostValidator.java +++ b/briar-core/src/org/briarproject/forum/ForumPostValidator.java @@ -13,6 +13,7 @@ import org.briarproject.api.data.MetadataEncoder; import org.briarproject.api.identity.Author; import org.briarproject.api.identity.AuthorFactory; import org.briarproject.api.sync.Group; +import org.briarproject.api.sync.Message; import org.briarproject.api.system.Clock; import org.briarproject.clients.BdfMessageValidator; @@ -38,16 +39,16 @@ class ForumPostValidator extends BdfMessageValidator { } @Override - protected BdfDictionary validateMessage(BdfList message, Group g, - long timestamp) throws FormatException { + protected BdfDictionary validateMessage(Message m, Group g, + BdfList body) throws FormatException { // Parent ID, author, content type, forum post body, signature - checkSize(message, 5); + checkSize(body, 5); // Parent ID is optional - byte[] parent = message.getOptionalRaw(0); + byte[] parent = body.getOptionalRaw(0); checkLength(parent, UniqueId.LENGTH); // Author is optional Author author = null; - BdfList authorList = message.getOptionalList(1); + BdfList authorList = body.getOptionalList(1); if (authorList != null) { // Name, public key checkSize(authorList, 2); @@ -58,13 +59,13 @@ class ForumPostValidator extends BdfMessageValidator { author = authorFactory.createAuthor(name, publicKey); } // Content type - String contentType = message.getString(2); + String contentType = body.getString(2); checkLength(contentType, 0, MAX_CONTENT_TYPE_LENGTH); // Forum post body - byte[] body = message.getRaw(3); - checkLength(body, 0, MAX_FORUM_POST_BODY_LENGTH); + byte[] forumPostBody = body.getRaw(3); + checkLength(forumPostBody, 0, MAX_FORUM_POST_BODY_LENGTH); // Signature is optional - byte[] sig = message.getOptionalRaw(4); + byte[] sig = body.getOptionalRaw(4); checkLength(sig, 0, MAX_SIGNATURE_LENGTH); // If there's an author there must be a signature and vice versa if (author != null && sig == null) { @@ -82,7 +83,7 @@ class ForumPostValidator extends BdfMessageValidator { KeyParser keyParser = crypto.getSignatureKeyParser(); PublicKey key = keyParser.parsePublicKey(author.getPublicKey()); // Serialise the data to be signed - BdfList signed = BdfList.of(g.getId(), timestamp, parent, + BdfList signed = BdfList.of(g.getId(), m.getTimestamp(), parent, authorList, contentType, body); // Verify the signature Signature signature = crypto.getSignature(); @@ -99,7 +100,7 @@ class ForumPostValidator extends BdfMessageValidator { } // Return the metadata BdfDictionary meta = new BdfDictionary(); - meta.put("timestamp", timestamp); + meta.put("timestamp", m.getTimestamp()); if (parent != null) meta.put("parent", parent); if (author != null) { BdfDictionary authorMeta = new BdfDictionary(); diff --git a/briar-core/src/org/briarproject/forum/ForumSharingManagerImpl.java b/briar-core/src/org/briarproject/forum/ForumSharingManagerImpl.java index 834cb94ddb8c787baa1882126d68ae14a9c4d7d0..aa083679df926ce7e8e94cd807657998de42cad3 100644 --- a/briar-core/src/org/briarproject/forum/ForumSharingManagerImpl.java +++ b/briar-core/src/org/briarproject/forum/ForumSharingManagerImpl.java @@ -24,7 +24,7 @@ import org.briarproject.api.sync.GroupFactory; import org.briarproject.api.sync.GroupId; import org.briarproject.api.sync.Message; import org.briarproject.api.sync.MessageId; -import org.briarproject.api.sync.ValidationManager.ValidationHook; +import org.briarproject.api.sync.ValidationManager.IncomingMessageHook; import org.briarproject.api.system.Clock; import org.briarproject.util.StringUtils; @@ -44,7 +44,7 @@ import static org.briarproject.api.forum.ForumConstants.MAX_FORUM_NAME_LENGTH; import static org.briarproject.api.sync.SyncConstants.MESSAGE_HEADER_LENGTH; class ForumSharingManagerImpl implements ForumSharingManager, AddContactHook, - RemoveContactHook, ValidationHook { + RemoveContactHook, IncomingMessageHook { static final ClientId CLIENT_ID = new ClientId(StringUtils.fromHexString( "cd11a5d04dccd9e2931d6fc3df456313" @@ -103,15 +103,13 @@ class ForumSharingManagerImpl implements ForumSharingManager, AddContactHook, } @Override - public void validatingMessage(Transaction txn, Message m, ClientId c, - Metadata meta) throws DbException { - if (c.equals(CLIENT_ID)) { - try { - ContactId contactId = getContactId(txn, m.getGroupId()); - setForumVisibility(txn, contactId, getVisibleForums(txn, m)); - } catch (FormatException e) { - throw new DbException(e); - } + public void incomingMessage(Transaction txn, Message m, Metadata meta) + throws DbException { + try { + ContactId contactId = getContactId(txn, m.getGroupId()); + setForumVisibility(txn, contactId, getVisibleForums(txn, m)); + } catch (FormatException e) { + throw new DbException(e); } } diff --git a/briar-core/src/org/briarproject/messaging/PrivateMessageValidator.java b/briar-core/src/org/briarproject/messaging/PrivateMessageValidator.java index 0475da174b0d120de0c38f6628e1d2ec6f0fc10f..550da3f9c77d7ab8eaebc9365105275a35cb23af 100644 --- a/briar-core/src/org/briarproject/messaging/PrivateMessageValidator.java +++ b/briar-core/src/org/briarproject/messaging/PrivateMessageValidator.java @@ -7,6 +7,7 @@ import org.briarproject.api.data.BdfDictionary; import org.briarproject.api.data.BdfList; import org.briarproject.api.data.MetadataEncoder; import org.briarproject.api.sync.Group; +import org.briarproject.api.sync.Message; import org.briarproject.api.system.Clock; import org.briarproject.clients.BdfMessageValidator; @@ -21,22 +22,22 @@ class PrivateMessageValidator extends BdfMessageValidator { } @Override - protected BdfDictionary validateMessage(BdfList message, Group g, - long timestamp) throws FormatException { + protected BdfDictionary validateMessage(Message m, Group g, + BdfList body) throws FormatException { // Parent ID, content type, private message body - checkSize(message, 3); + checkSize(body, 3); // Parent ID is optional - byte[] parentId = message.getOptionalRaw(0); + byte[] parentId = body.getOptionalRaw(0); checkLength(parentId, UniqueId.LENGTH); // Content type - String contentType = message.getString(1); + String contentType = body.getString(1); checkLength(contentType, 0, MAX_CONTENT_TYPE_LENGTH); // Private message body - byte[] body = message.getRaw(2); - checkLength(body, 0, MAX_PRIVATE_MESSAGE_BODY_LENGTH); + byte[] privateMessageBody = body.getRaw(2); + checkLength(privateMessageBody, 0, MAX_PRIVATE_MESSAGE_BODY_LENGTH); // Return the metadata BdfDictionary meta = new BdfDictionary(); - meta.put("timestamp", timestamp); + meta.put("timestamp", m.getTimestamp()); if (parentId != null) meta.put("parent", parentId); meta.put("contentType", contentType); meta.put("local", false); diff --git a/briar-core/src/org/briarproject/properties/TransportPropertyValidator.java b/briar-core/src/org/briarproject/properties/TransportPropertyValidator.java index e6fcd69a6c6c96c9eb87033f2c882962ba6ea99b..55b913c7dd2f60aa935773158602a0fb27cf65a5 100644 --- a/briar-core/src/org/briarproject/properties/TransportPropertyValidator.java +++ b/briar-core/src/org/briarproject/properties/TransportPropertyValidator.java @@ -7,6 +7,7 @@ import org.briarproject.api.data.BdfDictionary; import org.briarproject.api.data.BdfList; import org.briarproject.api.data.MetadataEncoder; import org.briarproject.api.sync.Group; +import org.briarproject.api.sync.Message; import org.briarproject.api.system.Clock; import org.briarproject.clients.BdfMessageValidator; @@ -22,21 +23,21 @@ class TransportPropertyValidator extends BdfMessageValidator { } @Override - protected BdfDictionary validateMessage(BdfList message, Group g, - long timestamp) throws FormatException { + protected BdfDictionary validateMessage(Message m, Group g, + BdfList body) throws FormatException { // Device ID, transport ID, version, properties - checkSize(message, 4); + checkSize(body, 4); // Device ID - byte[] deviceId = message.getRaw(0); + byte[] deviceId = body.getRaw(0); checkLength(deviceId, UniqueId.LENGTH); // Transport ID - String transportId = message.getString(1); + String transportId = body.getString(1); checkLength(transportId, 1, MAX_TRANSPORT_ID_LENGTH); // Version - long version = message.getLong(2); + long version = body.getLong(2); if (version < 0) throw new FormatException(); // Properties - BdfDictionary dictionary = message.getDictionary(3); + BdfDictionary dictionary = body.getDictionary(3); checkSize(dictionary, 0, MAX_PROPERTIES_PER_TRANSPORT); for (String key : dictionary.keySet()) { checkLength(key, 0, MAX_PROPERTY_LENGTH); diff --git a/briar-core/src/org/briarproject/sync/ValidationManagerImpl.java b/briar-core/src/org/briarproject/sync/ValidationManagerImpl.java index 248602434ef9cadf46cab96105777a4630c13dad..949dcc516ee76cccb7ea85d2c77388fa8a59e320 100644 --- a/briar-core/src/org/briarproject/sync/ValidationManagerImpl.java +++ b/briar-core/src/org/briarproject/sync/ValidationManagerImpl.java @@ -20,16 +20,13 @@ import org.briarproject.api.sync.Group; import org.briarproject.api.sync.GroupId; import org.briarproject.api.sync.Message; import org.briarproject.api.sync.MessageId; -import org.briarproject.api.sync.MessageValidator; import org.briarproject.api.sync.ValidationManager; import org.briarproject.util.ByteUtils; import java.util.LinkedList; -import java.util.List; import java.util.Map; import java.util.Queue; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.Executor; import java.util.logging.Logger; @@ -46,7 +43,7 @@ class ValidationManagerImpl implements ValidationManager, Service, private final Executor dbExecutor; private final Executor cryptoExecutor; private final Map<ClientId, MessageValidator> validators; - private final List<ValidationHook> hooks; + private final Map<ClientId, IncomingMessageHook> hooks; @Inject ValidationManagerImpl(DatabaseComponent db, @@ -56,7 +53,7 @@ class ValidationManagerImpl implements ValidationManager, Service, this.dbExecutor = dbExecutor; this.cryptoExecutor = cryptoExecutor; validators = new ConcurrentHashMap<ClientId, MessageValidator>(); - hooks = new CopyOnWriteArrayList<ValidationHook>(); + hooks = new ConcurrentHashMap<ClientId, IncomingMessageHook>(); } @Override @@ -76,8 +73,9 @@ class ValidationManagerImpl implements ValidationManager, Service, } @Override - public void registerValidationHook(ValidationHook hook) { - hooks.add(hook); + public void registerIncomingMessageHook(ClientId c, + IncomingMessageHook hook) { + hooks.put(c, hook); } private void getMessagesToValidate(final ClientId c) { @@ -170,8 +168,9 @@ class ValidationManagerImpl implements ValidationManager, Service, db.mergeMessageMetadata(txn, m.getId(), meta); db.setMessageValid(txn, m, c, true); db.setMessageShared(txn, m, true); - for (ValidationHook hook : hooks) - hook.validatingMessage(txn, m, c, meta); + IncomingMessageHook hook = hooks.get(c); + if (hook != null) + hook.incomingMessage(txn, m, meta); } txn.setComplete(); } finally { diff --git a/briar-tests/src/org/briarproject/clients/MessageQueueManagerImplTest.java b/briar-tests/src/org/briarproject/clients/MessageQueueManagerImplTest.java new file mode 100644 index 0000000000000000000000000000000000000000..3cbed06cb9c4ab6835d62355ca739d1cf54bfdff --- /dev/null +++ b/briar-tests/src/org/briarproject/clients/MessageQueueManagerImplTest.java @@ -0,0 +1,572 @@ +package org.briarproject.clients; + +import org.briarproject.BriarTestCase; +import org.briarproject.TestUtils; +import org.briarproject.api.clients.ClientHelper; +import org.briarproject.api.clients.MessageQueueManager.IncomingQueueMessageHook; +import org.briarproject.api.clients.MessageQueueManager.QueueMessageValidator; +import org.briarproject.api.clients.QueueMessage; +import org.briarproject.api.clients.QueueMessageFactory; +import org.briarproject.api.data.BdfDictionary; +import org.briarproject.api.data.BdfList; +import org.briarproject.api.db.DatabaseComponent; +import org.briarproject.api.db.Metadata; +import org.briarproject.api.db.Transaction; +import org.briarproject.api.sync.ClientId; +import org.briarproject.api.sync.Group; +import org.briarproject.api.sync.GroupId; +import org.briarproject.api.sync.Message; +import org.briarproject.api.sync.MessageId; +import org.briarproject.api.sync.ValidationManager; +import org.briarproject.api.sync.ValidationManager.IncomingMessageHook; +import org.briarproject.api.sync.ValidationManager.MessageValidator; +import org.briarproject.util.ByteUtils; +import org.hamcrest.Description; +import org.jmock.Expectations; +import org.jmock.Mockery; +import org.jmock.api.Action; +import org.jmock.api.Invocation; +import org.junit.Test; + +import java.util.concurrent.atomic.AtomicReference; + +import static org.briarproject.api.clients.MessageQueueManager.QUEUE_STATE_KEY; +import static org.briarproject.api.clients.QueueMessage.QUEUE_MESSAGE_HEADER_LENGTH; +import static org.briarproject.api.sync.SyncConstants.MAX_GROUP_DESCRIPTOR_LENGTH; +import static org.briarproject.api.sync.SyncConstants.MESSAGE_HEADER_LENGTH; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +public class MessageQueueManagerImplTest extends BriarTestCase { + + private final GroupId groupId = new GroupId(TestUtils.getRandomId()); + private final ClientId clientId = new ClientId(TestUtils.getRandomId()); + private final byte[] descriptor = new byte[MAX_GROUP_DESCRIPTOR_LENGTH]; + private final Group group = new Group(groupId, clientId, descriptor); + private final long timestamp = System.currentTimeMillis(); + + @Test + public void testSendingMessages() throws Exception { + Mockery context = new Mockery(); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final ClientHelper clientHelper = context.mock(ClientHelper.class); + final QueueMessageFactory queueMessageFactory = + context.mock(QueueMessageFactory.class); + final ValidationManager validationManager = + context.mock(ValidationManager.class); + final Transaction txn = new Transaction(null); + final byte[] body = new byte[123]; + final Metadata groupMetadata = new Metadata(); + final Metadata messageMetadata = new Metadata(); + final Metadata groupMetadata1 = new Metadata(); + final byte[] queueState = new byte[123]; + groupMetadata1.put(QUEUE_STATE_KEY, queueState); + context.checking(new Expectations() {{ + // First message: queue state does not exist + oneOf(db).getGroupMetadata(txn, groupId); + will(returnValue(groupMetadata)); + oneOf(clientHelper).toByteArray(with(any(BdfDictionary.class))); + will(new EncodeQueueStateAction(1L, 0L, new BdfList())); + oneOf(db).mergeGroupMetadata(with(txn), with(groupId), + with(any(Metadata.class))); + oneOf(queueMessageFactory).createMessage(groupId, timestamp, 0L, + body); + will(new CreateMessageAction()); + oneOf(db).addLocalMessage(with(txn), with(any(QueueMessage.class)), + with(clientId), with(messageMetadata), with(true)); + // Second message: queue state exists + oneOf(db).getGroupMetadata(txn, groupId); + will(returnValue(groupMetadata1)); + oneOf(clientHelper).toDictionary(queueState, 0, queueState.length); + will(new DecodeQueueStateAction(1L, 0L, new BdfList())); + oneOf(clientHelper).toByteArray(with(any(BdfDictionary.class))); + will(new EncodeQueueStateAction(2L, 0L, new BdfList())); + oneOf(db).mergeGroupMetadata(with(txn), with(groupId), + with(any(Metadata.class))); + oneOf(queueMessageFactory).createMessage(groupId, timestamp, 1L, + body); + will(new CreateMessageAction()); + oneOf(db).addLocalMessage(with(txn), with(any(QueueMessage.class)), + with(clientId), with(messageMetadata), with(true)); + }}); + + MessageQueueManagerImpl mqm = new MessageQueueManagerImpl(db, + clientHelper, queueMessageFactory, validationManager); + + // First message + QueueMessage q = mqm.sendMessage(txn, group, timestamp, body, + messageMetadata); + assertEquals(groupId, q.getGroupId()); + assertEquals(timestamp, q.getTimestamp()); + assertEquals(0L, q.getQueuePosition()); + assertEquals(QUEUE_MESSAGE_HEADER_LENGTH + body.length, q.getLength()); + + // Second message + QueueMessage q1 = mqm.sendMessage(txn, group, timestamp, body, + messageMetadata); + assertEquals(groupId, q1.getGroupId()); + assertEquals(timestamp, q1.getTimestamp()); + assertEquals(1L, q1.getQueuePosition()); + assertEquals(QUEUE_MESSAGE_HEADER_LENGTH + body.length, q1.getLength()); + + context.assertIsSatisfied(); + } + + @Test + public void testValidatorRejectsShortMessage() throws Exception { + Mockery context = new Mockery(); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final ClientHelper clientHelper = context.mock(ClientHelper.class); + final QueueMessageFactory queueMessageFactory = + context.mock(QueueMessageFactory.class); + final ValidationManager validationManager = + context.mock(ValidationManager.class); + final AtomicReference<MessageValidator> captured = + new AtomicReference<MessageValidator>(); + final QueueMessageValidator queueMessageValidator = + context.mock(QueueMessageValidator.class); + // The message is too short to be a valid queue message + final MessageId messageId = new MessageId(TestUtils.getRandomId()); + final byte[] raw = new byte[QUEUE_MESSAGE_HEADER_LENGTH - 1]; + final Message message = new Message(messageId, groupId, timestamp, raw); + context.checking(new Expectations() {{ + oneOf(validationManager).registerMessageValidator(with(clientId), + with(any(MessageValidator.class))); + will(new CaptureArgumentAction<MessageValidator>(captured, + MessageValidator.class, 1)); + }}); + + + MessageQueueManagerImpl mqm = new MessageQueueManagerImpl(db, + clientHelper, queueMessageFactory, validationManager); + + // Capture the delegating message validator + mqm.registerMessageValidator(clientId, queueMessageValidator); + MessageValidator delegate = captured.get(); + assertNotNull(delegate); + // The message should be invalid + assertNull(delegate.validateMessage(message, group)); + + context.assertIsSatisfied(); + } + + @Test + public void testValidatorRejectsNegativeQueuePosition() throws Exception { + Mockery context = new Mockery(); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final ClientHelper clientHelper = context.mock(ClientHelper.class); + final QueueMessageFactory queueMessageFactory = + context.mock(QueueMessageFactory.class); + final ValidationManager validationManager = + context.mock(ValidationManager.class); + final AtomicReference<MessageValidator> captured = + new AtomicReference<MessageValidator>(); + final QueueMessageValidator queueMessageValidator = + context.mock(QueueMessageValidator.class); + // The message has a negative queue position + final MessageId messageId = new MessageId(TestUtils.getRandomId()); + final byte[] raw = new byte[QUEUE_MESSAGE_HEADER_LENGTH]; + for (int i = 0; i < 8; i++) + raw[MESSAGE_HEADER_LENGTH + i] = (byte) 0xFF; + final Message message = new Message(messageId, groupId, timestamp, raw); + context.checking(new Expectations() {{ + oneOf(validationManager).registerMessageValidator(with(clientId), + with(any(MessageValidator.class))); + will(new CaptureArgumentAction<MessageValidator>(captured, + MessageValidator.class, 1)); + }}); + + + MessageQueueManagerImpl mqm = new MessageQueueManagerImpl(db, + clientHelper, queueMessageFactory, validationManager); + + // Capture the delegating message validator + mqm.registerMessageValidator(clientId, queueMessageValidator); + MessageValidator delegate = captured.get(); + assertNotNull(delegate); + // The message should be invalid + assertNull(delegate.validateMessage(message, group)); + + context.assertIsSatisfied(); + } + + @Test + public void testValidatorDelegatesValidMessage() throws Exception { + Mockery context = new Mockery(); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final ClientHelper clientHelper = context.mock(ClientHelper.class); + final QueueMessageFactory queueMessageFactory = + context.mock(QueueMessageFactory.class); + final ValidationManager validationManager = + context.mock(ValidationManager.class); + final AtomicReference<MessageValidator> captured = + new AtomicReference<MessageValidator>(); + final QueueMessageValidator queueMessageValidator = + context.mock(QueueMessageValidator.class); + final Metadata messageMetadata = new Metadata(); + // The message is valid, with a queue position of zero + final MessageId messageId = new MessageId(TestUtils.getRandomId()); + final byte[] raw = new byte[QUEUE_MESSAGE_HEADER_LENGTH]; + final Message message = new Message(messageId, groupId, timestamp, raw); + context.checking(new Expectations() {{ + oneOf(validationManager).registerMessageValidator(with(clientId), + with(any(MessageValidator.class))); + will(new CaptureArgumentAction<MessageValidator>(captured, + MessageValidator.class, 1)); + // The message should be delegated + oneOf(queueMessageValidator).validateMessage( + with(any(QueueMessage.class)), with(group)); + will(returnValue(messageMetadata)); + }}); + + + MessageQueueManagerImpl mqm = new MessageQueueManagerImpl(db, + clientHelper, queueMessageFactory, validationManager); + + // Capture the delegating message validator + mqm.registerMessageValidator(clientId, queueMessageValidator); + MessageValidator delegate = captured.get(); + assertNotNull(delegate); + // The message should be valid and the metadata should be returned + assertSame(messageMetadata, delegate.validateMessage(message, group)); + + context.assertIsSatisfied(); + } + + @Test + public void testIncomingMessageHookDeletesDuplicateMessage() + throws Exception { + Mockery context = new Mockery(); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final ClientHelper clientHelper = context.mock(ClientHelper.class); + final QueueMessageFactory queueMessageFactory = + context.mock(QueueMessageFactory.class); + final ValidationManager validationManager = + context.mock(ValidationManager.class); + final AtomicReference<IncomingMessageHook> captured = + new AtomicReference<IncomingMessageHook>(); + final IncomingQueueMessageHook incomingQueueMessageHook = + context.mock(IncomingQueueMessageHook.class); + final Transaction txn = new Transaction(null); + final Metadata groupMetadata = new Metadata(); + final byte[] queueState = new byte[123]; + groupMetadata.put(QUEUE_STATE_KEY, queueState); + // The message has queue position 0 + final MessageId messageId = new MessageId(TestUtils.getRandomId()); + final byte[] raw = new byte[QUEUE_MESSAGE_HEADER_LENGTH]; + final Message message = new Message(messageId, groupId, timestamp, raw); + context.checking(new Expectations() {{ + oneOf(validationManager).registerIncomingMessageHook(with(clientId), + with(any(IncomingMessageHook.class))); + will(new CaptureArgumentAction<IncomingMessageHook>(captured, + IncomingMessageHook.class, 1)); + oneOf(db).getGroupMetadata(txn, groupId); + will(returnValue(groupMetadata)); + // Queue position 1 is expected + oneOf(clientHelper).toDictionary(queueState, 0, queueState.length); + will(new DecodeQueueStateAction(0L, 1L, new BdfList())); + // The message and its metadata should be deleted + oneOf(db).deleteMessage(txn, messageId); + oneOf(db).deleteMessageMetadata(txn, messageId); + }}); + + + MessageQueueManagerImpl mqm = new MessageQueueManagerImpl(db, + clientHelper, queueMessageFactory, validationManager); + + // Capture the delegating incoming message hook + mqm.registerIncomingMessageHook(clientId, incomingQueueMessageHook); + IncomingMessageHook delegate = captured.get(); + assertNotNull(delegate); + // Pass the message to the hook + delegate.incomingMessage(txn, message, new Metadata()); + + context.assertIsSatisfied(); + } + + @Test + public void testIncomingMessageHookAddsOutOfOrderMessageToPendingList() + throws Exception { + Mockery context = new Mockery(); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final ClientHelper clientHelper = context.mock(ClientHelper.class); + final QueueMessageFactory queueMessageFactory = + context.mock(QueueMessageFactory.class); + final ValidationManager validationManager = + context.mock(ValidationManager.class); + final AtomicReference<IncomingMessageHook> captured = + new AtomicReference<IncomingMessageHook>(); + final IncomingQueueMessageHook incomingQueueMessageHook = + context.mock(IncomingQueueMessageHook.class); + final Transaction txn = new Transaction(null); + final Metadata groupMetadata = new Metadata(); + final byte[] queueState = new byte[123]; + groupMetadata.put(QUEUE_STATE_KEY, queueState); + // The message has queue position 1 + final MessageId messageId = new MessageId(TestUtils.getRandomId()); + final byte[] raw = new byte[QUEUE_MESSAGE_HEADER_LENGTH]; + ByteUtils.writeUint64(1L, raw, MESSAGE_HEADER_LENGTH); + final Message message = new Message(messageId, groupId, timestamp, raw); + final BdfList pending = BdfList.of(BdfList.of(1L, messageId)); + context.checking(new Expectations() {{ + oneOf(validationManager).registerIncomingMessageHook(with(clientId), + with(any(IncomingMessageHook.class))); + will(new CaptureArgumentAction<IncomingMessageHook>(captured, + IncomingMessageHook.class, 1)); + oneOf(db).getGroupMetadata(txn, groupId); + will(returnValue(groupMetadata)); + // Queue position 0 is expected + oneOf(clientHelper).toDictionary(queueState, 0, queueState.length); + will(new DecodeQueueStateAction(0L, 0L, new BdfList())); + // The message should be added to the pending list + oneOf(clientHelper).toByteArray(with(any(BdfDictionary.class))); + will(new EncodeQueueStateAction(0L, 0L, pending)); + oneOf(db).mergeGroupMetadata(with(txn), with(groupId), + with(any(Metadata.class))); + }}); + + + MessageQueueManagerImpl mqm = new MessageQueueManagerImpl(db, + clientHelper, queueMessageFactory, validationManager); + + // Capture the delegating incoming message hook + mqm.registerIncomingMessageHook(clientId, incomingQueueMessageHook); + IncomingMessageHook delegate = captured.get(); + assertNotNull(delegate); + // Pass the message to the hook + delegate.incomingMessage(txn, message, new Metadata()); + + context.assertIsSatisfied(); + } + + @Test + public void testIncomingMessageHookDelegatesInOrderMessage() + throws Exception { + Mockery context = new Mockery(); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final ClientHelper clientHelper = context.mock(ClientHelper.class); + final QueueMessageFactory queueMessageFactory = + context.mock(QueueMessageFactory.class); + final ValidationManager validationManager = + context.mock(ValidationManager.class); + final AtomicReference<IncomingMessageHook> captured = + new AtomicReference<IncomingMessageHook>(); + final IncomingQueueMessageHook incomingQueueMessageHook = + context.mock(IncomingQueueMessageHook.class); + final Transaction txn = new Transaction(null); + final Metadata groupMetadata = new Metadata(); + final byte[] queueState = new byte[123]; + groupMetadata.put(QUEUE_STATE_KEY, queueState); + // The message has queue position 0 + final MessageId messageId = new MessageId(TestUtils.getRandomId()); + final byte[] raw = new byte[QUEUE_MESSAGE_HEADER_LENGTH]; + final Message message = new Message(messageId, groupId, timestamp, raw); + final Metadata messageMetadata = new Metadata(); + context.checking(new Expectations() {{ + oneOf(validationManager).registerIncomingMessageHook(with(clientId), + with(any(IncomingMessageHook.class))); + will(new CaptureArgumentAction<IncomingMessageHook>(captured, + IncomingMessageHook.class, 1)); + oneOf(db).getGroupMetadata(txn, groupId); + will(returnValue(groupMetadata)); + // Queue position 0 is expected + oneOf(clientHelper).toDictionary(queueState, 0, queueState.length); + will(new DecodeQueueStateAction(0L, 0L, new BdfList())); + // The message should be delegated + oneOf(incomingQueueMessageHook).incomingMessage(with(txn), + with(any(QueueMessage.class)), with(messageMetadata)); + // Queue position 1 should be expected next + oneOf(clientHelper).toByteArray(with(any(BdfDictionary.class))); + will(new EncodeQueueStateAction(0L, 1L, new BdfList())); + oneOf(db).mergeGroupMetadata(with(txn), with(groupId), + with(any(Metadata.class))); + }}); + + + MessageQueueManagerImpl mqm = new MessageQueueManagerImpl(db, + clientHelper, queueMessageFactory, validationManager); + + // Capture the delegating incoming message hook + mqm.registerIncomingMessageHook(clientId, incomingQueueMessageHook); + IncomingMessageHook delegate = captured.get(); + assertNotNull(delegate); + // Pass the message to the hook + delegate.incomingMessage(txn, message, messageMetadata); + + context.assertIsSatisfied(); + } + + @Test + public void testIncomingMessageHookRetrievesPendingMessage() + throws Exception { + Mockery context = new Mockery(); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final ClientHelper clientHelper = context.mock(ClientHelper.class); + final QueueMessageFactory queueMessageFactory = + context.mock(QueueMessageFactory.class); + final ValidationManager validationManager = + context.mock(ValidationManager.class); + final AtomicReference<IncomingMessageHook> captured = + new AtomicReference<IncomingMessageHook>(); + final IncomingQueueMessageHook incomingQueueMessageHook = + context.mock(IncomingQueueMessageHook.class); + final Transaction txn = new Transaction(null); + final Metadata groupMetadata = new Metadata(); + final byte[] queueState = new byte[123]; + groupMetadata.put(QUEUE_STATE_KEY, queueState); + // The message has queue position 0 + final MessageId messageId = new MessageId(TestUtils.getRandomId()); + final byte[] raw = new byte[QUEUE_MESSAGE_HEADER_LENGTH]; + final Message message = new Message(messageId, groupId, timestamp, raw); + final Metadata messageMetadata = new Metadata(); + // Queue position 1 is pending + final MessageId messageId1 = new MessageId(TestUtils.getRandomId()); + final byte[] raw1 = new byte[QUEUE_MESSAGE_HEADER_LENGTH]; + final QueueMessage message1 = new QueueMessage(messageId1, groupId, + timestamp, 1L, raw1); + final Metadata messageMetadata1 = new Metadata(); + final BdfList pending = BdfList.of(BdfList.of(1L, messageId1)); + context.checking(new Expectations() {{ + oneOf(validationManager).registerIncomingMessageHook(with(clientId), + with(any(IncomingMessageHook.class))); + will(new CaptureArgumentAction<IncomingMessageHook>(captured, + IncomingMessageHook.class, 1)); + oneOf(db).getGroupMetadata(txn, groupId); + will(returnValue(groupMetadata)); + // Queue position 0 is expected, position 1 is pending + oneOf(clientHelper).toDictionary(queueState, 0, queueState.length); + will(new DecodeQueueStateAction(0L, 0L, pending)); + // The message should be delegated + oneOf(incomingQueueMessageHook).incomingMessage(with(txn), + with(any(QueueMessage.class)), with(messageMetadata)); + // The pending message should be retrieved + oneOf(db).getRawMessage(txn, messageId1); + will(returnValue(raw1)); + oneOf(db).getMessageMetadata(txn, messageId1); + will(returnValue(messageMetadata1)); + oneOf(queueMessageFactory).createMessage(messageId1, raw1); + will(returnValue(message1)); + // The pending message should be delegated + oneOf(incomingQueueMessageHook).incomingMessage(txn, message1, + messageMetadata1); + // Queue position 2 should be expected next + oneOf(clientHelper).toByteArray(with(any(BdfDictionary.class))); + will(new EncodeQueueStateAction(0L, 2L, new BdfList())); + oneOf(db).mergeGroupMetadata(with(txn), with(groupId), + with(any(Metadata.class))); + }}); + + + MessageQueueManagerImpl mqm = new MessageQueueManagerImpl(db, + clientHelper, queueMessageFactory, validationManager); + + // Capture the delegating incoming message hook + mqm.registerIncomingMessageHook(clientId, incomingQueueMessageHook); + IncomingMessageHook delegate = captured.get(); + assertNotNull(delegate); + // Pass the message to the hook + delegate.incomingMessage(txn, message, messageMetadata); + + context.assertIsSatisfied(); + } + + private class EncodeQueueStateAction implements Action { + + private final long outgoingPosition, incomingPosition; + private final BdfList pending; + + private EncodeQueueStateAction(long outgoingPosition, + long incomingPosition, BdfList pending) { + this.outgoingPosition = outgoingPosition; + this.incomingPosition = incomingPosition; + this.pending = pending; + } + + @Override + public Object invoke(Invocation invocation) throws Throwable { + BdfDictionary d = (BdfDictionary) invocation.getParameter(0); + assertEquals(outgoingPosition, d.getLong("nextOut").longValue()); + assertEquals(incomingPosition, d.getLong("nextIn").longValue()); + assertEquals(pending, d.getList("pending")); + return new byte[123]; + } + + @Override + public void describeTo(Description description) { + description.appendText("encodes a queue state"); + } + } + + private class DecodeQueueStateAction implements Action { + + private final long outgoingPosition, incomingPosition; + private final BdfList pending; + + private DecodeQueueStateAction(long outgoingPosition, + long incomingPosition, BdfList pending) { + this.outgoingPosition = outgoingPosition; + this.incomingPosition = incomingPosition; + this.pending = pending; + } + + @Override + public Object invoke(Invocation invocation) throws Throwable { + BdfDictionary d = new BdfDictionary(); + d.put("nextOut", outgoingPosition); + d.put("nextIn", incomingPosition); + d.put("pending", pending); + return d; + } + + @Override + public void describeTo(Description description) { + description.appendText("decodes a queue state"); + } + } + + private class CreateMessageAction implements Action { + + @Override + public Object invoke(Invocation invocation) throws Throwable { + GroupId groupId = (GroupId) invocation.getParameter(0); + long timestamp = (Long) invocation.getParameter(1); + long queuePosition = (Long) invocation.getParameter(2); + byte[] body = (byte[]) invocation.getParameter(3); + byte[] raw = new byte[QUEUE_MESSAGE_HEADER_LENGTH + body.length]; + MessageId id = new MessageId(TestUtils.getRandomId()); + return new QueueMessage(id, groupId, timestamp, queuePosition, raw); + } + + @Override + public void describeTo(Description description) { + description.appendText("creates a message"); + } + } + + private class CaptureArgumentAction<T> implements Action { + + private final AtomicReference<T> captured; + private final Class<T> capturedClass; + private final int index; + + private CaptureArgumentAction(AtomicReference<T> captured, + Class<T> capturedClass, int index) { + this.captured = captured; + this.capturedClass = capturedClass; + this.index = index; + } + + @Override + public Object invoke(Invocation invocation) throws Throwable { + captured.set(capturedClass.cast(invocation.getParameter(index))); + return null; + } + + @Override + public void describeTo(Description description) { + description.appendText("captures an argument"); + } + } +} diff --git a/briar-tests/src/org/briarproject/sync/ValidationManagerImplTest.java b/briar-tests/src/org/briarproject/sync/ValidationManagerImplTest.java index b55bacd90694245b784f41b51f083a31c4186516..d883a1ef20f746256b698972bd0a450cbf837ca3 100644 --- a/briar-tests/src/org/briarproject/sync/ValidationManagerImplTest.java +++ b/briar-tests/src/org/briarproject/sync/ValidationManagerImplTest.java @@ -16,8 +16,8 @@ import org.briarproject.api.sync.Group; import org.briarproject.api.sync.GroupId; import org.briarproject.api.sync.Message; import org.briarproject.api.sync.MessageId; -import org.briarproject.api.sync.MessageValidator; -import org.briarproject.api.sync.ValidationManager.ValidationHook; +import org.briarproject.api.sync.ValidationManager.IncomingMessageHook; +import org.briarproject.api.sync.ValidationManager.MessageValidator; import org.briarproject.util.ByteUtils; import org.jmock.Expectations; import org.jmock.Mockery; @@ -56,7 +56,8 @@ public class ValidationManagerImplTest extends BriarTestCase { final Executor dbExecutor = new ImmediateExecutor(); final Executor cryptoExecutor = new ImmediateExecutor(); final MessageValidator validator = context.mock(MessageValidator.class); - final ValidationHook hook = context.mock(ValidationHook.class); + final IncomingMessageHook hook = + context.mock(IncomingMessageHook.class); final Transaction txn = new Transaction(null); final Transaction txn1 = new Transaction(null); final Transaction txn2 = new Transaction(null); @@ -87,7 +88,7 @@ public class ValidationManagerImplTest extends BriarTestCase { oneOf(db).setMessageValid(txn2, message, clientId, true); oneOf(db).setMessageShared(txn2, message, true); // Call the hook for the first message - oneOf(hook).validatingMessage(txn2, message, clientId, metadata); + oneOf(hook).incomingMessage(txn2, message, metadata); oneOf(db).endTransaction(txn2); // Load the second raw message and group oneOf(db).startTransaction(); @@ -110,7 +111,7 @@ public class ValidationManagerImplTest extends BriarTestCase { ValidationManagerImpl vm = new ValidationManagerImpl(db, dbExecutor, cryptoExecutor); vm.registerMessageValidator(clientId, validator); - vm.registerValidationHook(hook); + vm.registerIncomingMessageHook(clientId, hook); vm.start(); context.assertIsSatisfied(); @@ -124,7 +125,8 @@ public class ValidationManagerImplTest extends BriarTestCase { final Executor dbExecutor = new ImmediateExecutor(); final Executor cryptoExecutor = new ImmediateExecutor(); final MessageValidator validator = context.mock(MessageValidator.class); - final ValidationHook hook = context.mock(ValidationHook.class); + final IncomingMessageHook hook = + context.mock(IncomingMessageHook.class); final Transaction txn = new Transaction(null); final Transaction txn1 = new Transaction(null); final Transaction txn2 = new Transaction(null); @@ -163,7 +165,7 @@ public class ValidationManagerImplTest extends BriarTestCase { ValidationManagerImpl vm = new ValidationManagerImpl(db, dbExecutor, cryptoExecutor); vm.registerMessageValidator(clientId, validator); - vm.registerValidationHook(hook); + vm.registerIncomingMessageHook(clientId, hook); vm.start(); context.assertIsSatisfied(); @@ -177,7 +179,8 @@ public class ValidationManagerImplTest extends BriarTestCase { final Executor dbExecutor = new ImmediateExecutor(); final Executor cryptoExecutor = new ImmediateExecutor(); final MessageValidator validator = context.mock(MessageValidator.class); - final ValidationHook hook = context.mock(ValidationHook.class); + final IncomingMessageHook hook = + context.mock(IncomingMessageHook.class); final Transaction txn = new Transaction(null); final Transaction txn1 = new Transaction(null); final Transaction txn2 = new Transaction(null); @@ -219,7 +222,7 @@ public class ValidationManagerImplTest extends BriarTestCase { ValidationManagerImpl vm = new ValidationManagerImpl(db, dbExecutor, cryptoExecutor); vm.registerMessageValidator(clientId, validator); - vm.registerValidationHook(hook); + vm.registerIncomingMessageHook(clientId, hook); vm.start(); context.assertIsSatisfied(); @@ -232,7 +235,8 @@ public class ValidationManagerImplTest extends BriarTestCase { final Executor dbExecutor = new ImmediateExecutor(); final Executor cryptoExecutor = new ImmediateExecutor(); final MessageValidator validator = context.mock(MessageValidator.class); - final ValidationHook hook = context.mock(ValidationHook.class); + final IncomingMessageHook hook = + context.mock(IncomingMessageHook.class); final Transaction txn = new Transaction(null); final Transaction txn1 = new Transaction(null); context.checking(new Expectations() {{ @@ -252,14 +256,14 @@ public class ValidationManagerImplTest extends BriarTestCase { oneOf(db).setMessageValid(txn1, message, clientId, true); oneOf(db).setMessageShared(txn1, message, true); // Call the hook - oneOf(hook).validatingMessage(txn1, message, clientId, metadata); + oneOf(hook).incomingMessage(txn1, message, metadata); oneOf(db).endTransaction(txn1); }}); ValidationManagerImpl vm = new ValidationManagerImpl(db, dbExecutor, cryptoExecutor); vm.registerMessageValidator(clientId, validator); - vm.registerValidationHook(hook); + vm.registerIncomingMessageHook(clientId, hook); vm.eventOccurred(new MessageAddedEvent(message, contactId)); context.assertIsSatisfied(); @@ -272,12 +276,13 @@ public class ValidationManagerImplTest extends BriarTestCase { final Executor dbExecutor = new ImmediateExecutor(); final Executor cryptoExecutor = new ImmediateExecutor(); final MessageValidator validator = context.mock(MessageValidator.class); - final ValidationHook hook = context.mock(ValidationHook.class); + final IncomingMessageHook hook = + context.mock(IncomingMessageHook.class); ValidationManagerImpl vm = new ValidationManagerImpl(db, dbExecutor, cryptoExecutor); vm.registerMessageValidator(clientId, validator); - vm.registerValidationHook(hook); + vm.registerIncomingMessageHook(clientId, hook); vm.eventOccurred(new MessageAddedEvent(message, null)); context.assertIsSatisfied();