diff --git a/api/net/sf/briar/api/db/DatabaseComponent.java b/api/net/sf/briar/api/db/DatabaseComponent.java index 2f47365caa62ac4c2c75044943f65117f708220e..85b35016d697acc5561b74221d1462b0470cae41 100644 --- a/api/net/sf/briar/api/db/DatabaseComponent.java +++ b/api/net/sf/briar/api/db/DatabaseComponent.java @@ -17,17 +17,13 @@ import net.sf.briar.api.protocol.GroupId; import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.Offer; +import net.sf.briar.api.protocol.RawBatch; +import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportUpdate; -import net.sf.briar.api.protocol.writers.AckWriter; -import net.sf.briar.api.protocol.writers.BatchWriter; -import net.sf.briar.api.protocol.writers.OfferWriter; -import net.sf.briar.api.protocol.writers.RequestWriter; -import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter; -import net.sf.briar.api.protocol.writers.TransportUpdateWriter; import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionWindow; @@ -72,43 +68,46 @@ public interface DatabaseComponent { TransportIndex addTransport(TransportId t) throws DbException; /** - * Generates an acknowledgement for the given contact. - * @return True if any batch IDs were added to the acknowledgement. + * Generates an acknowledgement for the given contact. Returns null if + * there are no batches to acknowledge. */ - boolean generateAck(ContactId c, AckWriter a) throws DbException, - IOException; + Ack generateAck(ContactId c, int maxBatches) throws DbException; /** - * Generates a batch of messages for the given contact. - * @return True if any messages were added to tbe batch. + * Generates a batch of messages for the given contact. Returns null if + * there are no sendable messages that fit in the given capacity. */ - boolean generateBatch(ContactId c, BatchWriter b) throws DbException, - IOException; + RawBatch generateBatch(ContactId c, int capacity) throws DbException; /** * Generates a batch of messages for the given contact from the given * collection of requested messages. Any messages that were either added to * the batch, or were considered but are no longer sendable to the contact, * are removed from the collection of requested messages before returning. - * @return True if any messages were added to the batch. + * Returns null if there are no sendable messages that fit in the given + * capacity. */ - boolean generateBatch(ContactId c, BatchWriter b, - Collection<MessageId> requested) throws DbException, IOException; + RawBatch generateBatch(ContactId c, int capacity, + Collection<MessageId> requested) throws DbException; /** - * Generates an offer for the given contact and returns the offered - * message IDs. + * Generates an offer for the given contact. Returns null if there are no + * messages to offer. */ - Collection<MessageId> generateOffer(ContactId c, OfferWriter o) - throws DbException, IOException; + Offer generateOffer(ContactId c, int maxMessages) throws DbException; - /** Generates a subscription update for the given contact. */ - void generateSubscriptionUpdate(ContactId c, SubscriptionUpdateWriter s) - throws DbException, IOException; + /** + * Generates a subscription update for the given contact. Returns null if + * an update is not due. + */ + SubscriptionUpdate generateSubscriptionUpdate(ContactId c) + throws DbException; - /** Generates a transport update for the given contact. */ - void generateTransportUpdate(ContactId c, TransportUpdateWriter t) - throws DbException, IOException; + /** + * Generates a transport update for the given contact. Returns null if an + * update is not due. + */ + TransportUpdate generateTransportUpdate(ContactId c) throws DbException; /** Returns the configuration for the given transport. */ TransportConfig getConfig(TransportId t) throws DbException; @@ -185,8 +184,7 @@ public interface DatabaseComponent { * to the contact are requested just as though they were not present in the * database. */ - void receiveOffer(ContactId c, Offer o, RequestWriter r) throws DbException, - IOException; + Request receiveOffer(ContactId c, Offer o) throws DbException; /** Processes a subscription update from the given contact. */ void receiveSubscriptionUpdate(ContactId c, SubscriptionUpdate s) diff --git a/api/net/sf/briar/api/protocol/Batch.java b/api/net/sf/briar/api/protocol/Batch.java index eb46a8f7de53cd771955936bc4415b38626e009c..eec1fef477b332716916826ff9c842c6335bb708 100644 --- a/api/net/sf/briar/api/protocol/Batch.java +++ b/api/net/sf/briar/api/protocol/Batch.java @@ -2,7 +2,7 @@ package net.sf.briar.api.protocol; import java.util.Collection; -/** A packet containing messages. */ +/** An incoming packet containing messages. */ public interface Batch { /** Returns the batch's unique identifier. */ diff --git a/api/net/sf/briar/api/protocol/Message.java b/api/net/sf/briar/api/protocol/Message.java index e292bcab78c7b856375c6c5051419485b1206189..23b0e745ade49284dc5858773fcfce4f98d6cf38 100644 --- a/api/net/sf/briar/api/protocol/Message.java +++ b/api/net/sf/briar/api/protocol/Message.java @@ -23,9 +23,6 @@ public interface Message { /** Returns the timestamp created by the message's author. */ long getTimestamp(); - /** Returns the length of the serialised message in bytes. */ - int getLength(); - /** Returns the serialised message. */ byte[] getSerialised(); diff --git a/api/net/sf/briar/api/protocol/PacketFactory.java b/api/net/sf/briar/api/protocol/PacketFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..ab9413fc5be60b7c0aaebb17ff7118a0db8e4d58 --- /dev/null +++ b/api/net/sf/briar/api/protocol/PacketFactory.java @@ -0,0 +1,22 @@ +package net.sf.briar.api.protocol; + +import java.util.BitSet; +import java.util.Collection; +import java.util.Map; + +public interface PacketFactory { + + Ack createAck(Collection<BatchId> acked); + + RawBatch createBatch(Collection<byte[]> messages); + + Offer createOffer(Collection<MessageId> offered); + + Request createRequest(BitSet requested, int length); + + SubscriptionUpdate createSubscriptionUpdate(Map<Group, Long> subs, + long timestamp); + + TransportUpdate createTransportUpdate(Collection<Transport> transports, + long timestamp); +} diff --git a/api/net/sf/briar/api/protocol/ProtocolWriter.java b/api/net/sf/briar/api/protocol/ProtocolWriter.java new file mode 100644 index 0000000000000000000000000000000000000000..2e7c6aaf8cf26413809f9a5a5d55371412db38c6 --- /dev/null +++ b/api/net/sf/briar/api/protocol/ProtocolWriter.java @@ -0,0 +1,24 @@ +package net.sf.briar.api.protocol; + +import java.io.IOException; + +public interface ProtocolWriter { + + int getMaxBatchesForAck(long capacity); + + int getMaxMessagesForOffer(long capacity); + + int getMessageCapacityForBatch(long capacity); + + void writeAck(Ack a) throws IOException; + + void writeBatch(RawBatch b) throws IOException; + + void writeOffer(Offer o) throws IOException; + + void writeRequest(Request r) throws IOException; + + void writeSubscriptionUpdate(SubscriptionUpdate s) throws IOException; + + void writeTransportUpdate(TransportUpdate t) throws IOException; +} diff --git a/api/net/sf/briar/api/protocol/ProtocolWriterFactory.java b/api/net/sf/briar/api/protocol/ProtocolWriterFactory.java new file mode 100644 index 0000000000000000000000000000000000000000..fbed9d08577c3360d6f457d19eb941c3e7c57317 --- /dev/null +++ b/api/net/sf/briar/api/protocol/ProtocolWriterFactory.java @@ -0,0 +1,8 @@ +package net.sf.briar.api.protocol; + +import java.io.OutputStream; + +public interface ProtocolWriterFactory { + + ProtocolWriter createProtocolWriter(OutputStream out); +} diff --git a/api/net/sf/briar/api/protocol/RawBatch.java b/api/net/sf/briar/api/protocol/RawBatch.java new file mode 100644 index 0000000000000000000000000000000000000000..f1337f645813b6c1abb91d7f57b33911df7eb802 --- /dev/null +++ b/api/net/sf/briar/api/protocol/RawBatch.java @@ -0,0 +1,13 @@ +package net.sf.briar.api.protocol; + +import java.util.Collection; + +/** An outgoing packet containing messages. */ +public interface RawBatch { + + /** Returns the batch's unique identifier. */ + BatchId getId(); + + /** Returns the serialised messages contained in the batch. */ + Collection<byte[]> getMessages(); +} diff --git a/api/net/sf/briar/api/protocol/Request.java b/api/net/sf/briar/api/protocol/Request.java index a17023020748cc95dd8a0499ae5047d0d3e29e21..242e59bbd23d19a558c5d03350ad602ec17b954c 100644 --- a/api/net/sf/briar/api/protocol/Request.java +++ b/api/net/sf/briar/api/protocol/Request.java @@ -10,4 +10,7 @@ public interface Request { * the offer, where the i^th bit is set if the i^th message should be sent. */ BitSet getBitmap(); + + /** Returns the length of the bitmap in bits. */ + int getLength(); } diff --git a/api/net/sf/briar/api/protocol/Types.java b/api/net/sf/briar/api/protocol/Types.java index cc6ed4bc9da58360e233bfd534a074b4ae13242b..a3a5919d410068f7a782de0229ff4368d9baf1b5 100644 --- a/api/net/sf/briar/api/protocol/Types.java +++ b/api/net/sf/briar/api/protocol/Types.java @@ -3,6 +3,7 @@ package net.sf.briar.api.protocol; /** Struct identifiers for encoding and decoding protocol objects. */ public interface Types { + // FIXME: Batch ID, message ID don't need to be structs static final int ACK = 0; static final int AUTHOR = 1; static final int BATCH = 2; diff --git a/api/net/sf/briar/api/protocol/writers/AckWriter.java b/api/net/sf/briar/api/protocol/writers/AckWriter.java deleted file mode 100644 index 80fbd88681d7492fd0522e207f713e95f3bcbf40..0000000000000000000000000000000000000000 --- a/api/net/sf/briar/api/protocol/writers/AckWriter.java +++ /dev/null @@ -1,24 +0,0 @@ -package net.sf.briar.api.protocol.writers; - -import java.io.IOException; - -import net.sf.briar.api.protocol.BatchId; - -/** An interface for creating an ack packet. */ -public interface AckWriter { - - /** - * Sets the maximum length of the serialised ack. If this method is not - * called, the default is ProtocolConstants.MAX_PACKET_LENGTH; - */ - void setMaxPacketLength(int length); - - /** - * Attempts to add the given BatchId to the ack and returns true if it - * was added. - */ - boolean writeBatchId(BatchId b) throws IOException; - - /** Finishes writing the ack. */ - void finish() throws IOException; -} diff --git a/api/net/sf/briar/api/protocol/writers/BatchWriter.java b/api/net/sf/briar/api/protocol/writers/BatchWriter.java deleted file mode 100644 index b4806166879ace9abbe2e36bd33e718321a76c21..0000000000000000000000000000000000000000 --- a/api/net/sf/briar/api/protocol/writers/BatchWriter.java +++ /dev/null @@ -1,27 +0,0 @@ -package net.sf.briar.api.protocol.writers; - -import java.io.IOException; - -import net.sf.briar.api.protocol.BatchId; - -/** An interface for creating a batch packet. */ -public interface BatchWriter { - - /** Returns the capacity of the batch in bytes. */ - int getCapacity(); - - /** - * Sets the maximum length of the serialised batch; the default is - * ProtocolConstants.MAX_PACKET_LENGTH; - */ - void setMaxPacketLength(int length); - - /** - * Attempts to add the given raw message to the batch and returns true if - * it was added. - */ - boolean writeMessage(byte[] raw) throws IOException; - - /** Finishes writing the batch and returns its unique identifier. */ - BatchId finish() throws IOException; -} diff --git a/api/net/sf/briar/api/protocol/writers/OfferWriter.java b/api/net/sf/briar/api/protocol/writers/OfferWriter.java deleted file mode 100644 index 35bdafe926ca6230db171c683b1b087efc08bf3e..0000000000000000000000000000000000000000 --- a/api/net/sf/briar/api/protocol/writers/OfferWriter.java +++ /dev/null @@ -1,24 +0,0 @@ -package net.sf.briar.api.protocol.writers; - -import java.io.IOException; - -import net.sf.briar.api.protocol.MessageId; - -/** An interface for creating an offer packet. */ -public interface OfferWriter { - - /** - * Sets the maximum length of the serialised offer. If this method is not - * called, the default is ProtocolConstants.MAX_PACKET_LENGTH; - */ - void setMaxPacketLength(int length); - - /** - * Attempts to add the given message ID to the offer and returns true if it - * was added. - */ - boolean writeMessageId(MessageId m) throws IOException; - - /** Finishes writing the offer. */ - void finish() throws IOException; -} diff --git a/api/net/sf/briar/api/protocol/writers/ProtocolWriterFactory.java b/api/net/sf/briar/api/protocol/writers/ProtocolWriterFactory.java deleted file mode 100644 index a0e88b4faad3249d360b46bd40f7930160b647db..0000000000000000000000000000000000000000 --- a/api/net/sf/briar/api/protocol/writers/ProtocolWriterFactory.java +++ /dev/null @@ -1,18 +0,0 @@ -package net.sf.briar.api.protocol.writers; - -import java.io.OutputStream; - -public interface ProtocolWriterFactory { - - AckWriter createAckWriter(OutputStream out); - - BatchWriter createBatchWriter(OutputStream out); - - OfferWriter createOfferWriter(OutputStream out); - - RequestWriter createRequestWriter(OutputStream out); - - SubscriptionUpdateWriter createSubscriptionUpdateWriter(OutputStream out); - - TransportUpdateWriter createTransportUpdateWriter(OutputStream out); -} diff --git a/api/net/sf/briar/api/protocol/writers/RequestWriter.java b/api/net/sf/briar/api/protocol/writers/RequestWriter.java deleted file mode 100644 index 01f4b0c1179a1f6c17c76b54e38d98ba0497279b..0000000000000000000000000000000000000000 --- a/api/net/sf/briar/api/protocol/writers/RequestWriter.java +++ /dev/null @@ -1,11 +0,0 @@ -package net.sf.briar.api.protocol.writers; - -import java.io.IOException; -import java.util.BitSet; - -/** An interface for creating a request packet. */ -public interface RequestWriter { - - /** Writes the contents of the request. */ - void writeRequest(BitSet b, int length) throws IOException; -} diff --git a/api/net/sf/briar/api/protocol/writers/SubscriptionUpdateWriter.java b/api/net/sf/briar/api/protocol/writers/SubscriptionUpdateWriter.java deleted file mode 100644 index 87c0ae5cd7e658e60a614e0161cbe334c80e8cd5..0000000000000000000000000000000000000000 --- a/api/net/sf/briar/api/protocol/writers/SubscriptionUpdateWriter.java +++ /dev/null @@ -1,14 +0,0 @@ -package net.sf.briar.api.protocol.writers; - -import java.io.IOException; -import java.util.Map; - -import net.sf.briar.api.protocol.Group; - -/** An interface for creating a subscription update. */ -public interface SubscriptionUpdateWriter { - - /** Writes the contents of the update. */ - void writeSubscriptions(Map<Group, Long> subs, long timestamp) - throws IOException; -} diff --git a/api/net/sf/briar/api/protocol/writers/TransportUpdateWriter.java b/api/net/sf/briar/api/protocol/writers/TransportUpdateWriter.java deleted file mode 100644 index e3ea60838079576b370c81daede0fc64d00ebd38..0000000000000000000000000000000000000000 --- a/api/net/sf/briar/api/protocol/writers/TransportUpdateWriter.java +++ /dev/null @@ -1,14 +0,0 @@ -package net.sf.briar.api.protocol.writers; - -import java.io.IOException; -import java.util.Collection; - -import net.sf.briar.api.protocol.Transport; - -/** An interface for creating a transport update. */ -public interface TransportUpdateWriter { - - /** Writes the contents of the update. */ - void writeTransports(Collection<Transport> transports, long timestamp) - throws IOException; -} diff --git a/api/net/sf/briar/api/serial/SerialComponent.java b/api/net/sf/briar/api/serial/SerialComponent.java index b68a499d43ec78a8cd00f337002682b039878882..f880083b9e5f5b658b01b76c7d6b5020915a6003 100644 --- a/api/net/sf/briar/api/serial/SerialComponent.java +++ b/api/net/sf/briar/api/serial/SerialComponent.java @@ -6,7 +6,7 @@ public interface SerialComponent { int getSerialisedListStartLength(); - int getSerialisedUniqueIdLength(int id); - int getSerialisedStructIdLength(int id); + + int getSerialisedUniqueIdLength(int id); } diff --git a/components/net/sf/briar/db/Database.java b/components/net/sf/briar/db/Database.java index 02dc62d98b72d934c3588ba7ad2185912be283f0..160b172e01e50c301b2bf5be207f0b7a247fb842 100644 --- a/components/net/sf/briar/db/Database.java +++ b/components/net/sf/briar/db/Database.java @@ -178,7 +178,8 @@ interface Database<T> { * <p> * Locking: contact read, messageStatus read. */ - Collection<BatchId> getBatchesToAck(T txn, ContactId c) throws DbException; + Collection<BatchId> getBatchesToAck(T txn, ContactId c, int maxBatches) + throws DbException; /** * Returns the configuration for the given transport. @@ -315,6 +316,16 @@ interface Database<T> { */ int getNumberOfSendableChildren(T txn, MessageId m) throws DbException; + /** + * Returns the IDs of some messages that are eligible to be sent to the + * given contact, up to the given number of messages. + * <p> + * Locking: contact read, message read, messageStatus read, + * subscription read. + */ + Collection<MessageId> getOfferableMessages(T txn, ContactId c, + int maxMessages) throws DbException; + /** * Returns the IDs of the oldest messages in the database, with a total * size less than or equal to the given size. @@ -361,16 +372,6 @@ interface Database<T> { */ int getSendability(T txn, MessageId m) throws DbException; - /** - * Returns the IDs of some messages that are eligible to be sent to the - * given contact. - * <p> - * Locking: contact read, message read, messageStatus read, - * subscription read. - */ - Collection<MessageId> getSendableMessages(T txn, ContactId c) - throws DbException; - /** * Returns the IDs of some messages that are eligible to be sent to the * given contact, with a total size less than or equal to the given size. diff --git a/components/net/sf/briar/db/DatabaseComponentImpl.java b/components/net/sf/briar/db/DatabaseComponentImpl.java index fff0e963fa1f212260c2f1f772ec33f12e28f173..f7d2d31bfcc19711a5eeaf2bca3749cf388677a4 100644 --- a/components/net/sf/briar/db/DatabaseComponentImpl.java +++ b/components/net/sf/briar/db/DatabaseComponentImpl.java @@ -20,7 +20,6 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.logging.Level; import java.util.logging.Logger; -import net.sf.briar.api.Bytes; import net.sf.briar.api.ContactId; import net.sf.briar.api.Rating; import net.sf.briar.api.TransportConfig; @@ -51,17 +50,14 @@ import net.sf.briar.api.protocol.GroupId; import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.Offer; +import net.sf.briar.api.protocol.PacketFactory; +import net.sf.briar.api.protocol.RawBatch; +import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportUpdate; -import net.sf.briar.api.protocol.writers.AckWriter; -import net.sf.briar.api.protocol.writers.BatchWriter; -import net.sf.briar.api.protocol.writers.OfferWriter; -import net.sf.briar.api.protocol.writers.RequestWriter; -import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter; -import net.sf.briar.api.protocol.writers.TransportUpdateWriter; import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.util.ByteUtils; @@ -105,6 +101,7 @@ DatabaseCleaner.Callback { private final Database<T> db; private final DatabaseCleaner cleaner; private final ShutdownManager shutdown; + private final PacketFactory packetFactory; private final Collection<DatabaseListener> listeners = new CopyOnWriteArrayList<DatabaseListener>(); @@ -119,10 +116,11 @@ DatabaseCleaner.Callback { @Inject DatabaseComponentImpl(Database<T> db, DatabaseCleaner cleaner, - ShutdownManager shutdown) { + ShutdownManager shutdown, PacketFactory packetFactory) { this.db = db; this.cleaner = cleaner; this.shutdown = shutdown; + this.packetFactory = packetFactory; } public void open(boolean resume) throws DbException, IOException { @@ -265,7 +263,7 @@ DatabaseCleaner.Callback { if(sendability > 0) updateAncestorSendability(txn, id, true); // Count the bytes stored synchronized(spaceLock) { - bytesStoredSinceLastCheck += m.getLength(); + bytesStoredSinceLastCheck += m.getSerialised().length; } } return stored; @@ -373,7 +371,7 @@ DatabaseCleaner.Callback { else db.setStatus(txn, c, id, Status.NEW); // Count the bytes stored synchronized(spaceLock) { - bytesStoredSinceLastCheck += m.getLength(); + bytesStoredSinceLastCheck += m.getSerialised().length; } return true; } @@ -415,17 +413,16 @@ DatabaseCleaner.Callback { return i; } - public boolean generateAck(ContactId c, AckWriter a) throws DbException, - IOException { + public Ack generateAck(ContactId c, int maxBatches) throws DbException { + Collection<BatchId> acked; contactLock.readLock().lock(); try { if(!containsContact(c)) throw new NoSuchContactException(); - Collection<BatchId> acks, sent = new ArrayList<BatchId>(); messageStatusLock.readLock().lock(); try { T txn = db.startTransaction(); try { - acks = db.getBatchesToAck(txn, c); + acked = db.getBatchesToAck(txn, c, maxBatches); db.commitTransaction(txn); } catch(DbException e) { db.abortTransaction(txn); @@ -434,20 +431,14 @@ DatabaseCleaner.Callback { } finally { messageStatusLock.readLock().unlock(); } - for(BatchId b : acks) { - if(!a.writeBatchId(b)) break; - sent.add(b); - } - // Record the contents of the ack, unless it's empty - if(sent.isEmpty()) return false; - a.finish(); + if(acked.isEmpty()) return null; + // Record the contents of the ack messageStatusLock.writeLock().lock(); try { T txn = db.startTransaction(); try { - db.removeBatchesToAck(txn, c, sent); + db.removeBatchesToAck(txn, c, acked); db.commitTransaction(txn); - return true; } catch(DbException e) { db.abortTransaction(txn); throw e; @@ -458,12 +449,14 @@ DatabaseCleaner.Callback { } finally { contactLock.readLock().unlock(); } + return packetFactory.createAck(acked); } - public boolean generateBatch(ContactId c, BatchWriter b) throws DbException, - IOException { + public RawBatch generateBatch(ContactId c, int capacity) + throws DbException { Collection<MessageId> ids; - Collection<Bytes> messages = new ArrayList<Bytes>(); + List<byte[]> messages = new ArrayList<byte[]>(); + RawBatch b; // Get some sendable messages from the database contactLock.readLock().lock(); try { @@ -476,10 +469,9 @@ DatabaseCleaner.Callback { try { T txn = db.startTransaction(); try { - int capacity = b.getCapacity(); ids = db.getSendableMessages(txn, c, capacity); for(MessageId m : ids) { - messages.add(new Bytes(db.getMessage(txn, m))); + messages.add(db.getMessage(txn, m)); } db.commitTransaction(txn); } catch(DbException e) { @@ -492,40 +484,14 @@ DatabaseCleaner.Callback { } finally { messageStatusLock.readLock().unlock(); } - } finally { - messageLock.readLock().unlock(); - } - } finally { - contactLock.readLock().unlock(); - } - if(ids.isEmpty()) return false; - writeAndRecordBatch(c, b, ids, messages); - return true; - } - - private void writeAndRecordBatch(ContactId c, BatchWriter b, - Collection<MessageId> ids, Collection<Bytes> messages) - throws DbException, IOException { - assert !ids.isEmpty(); - assert !messages.isEmpty(); - assert ids.size() == messages.size(); - // Add the messages to the batch - for(Bytes raw : messages) { - boolean written = b.writeMessage(raw.getBytes()); - assert written; - } - BatchId id = b.finish(); - // Record the contents of the batch - contactLock.readLock().lock(); - try { - if(!containsContact(c)) throw new NoSuchContactException(); - messageLock.readLock().lock(); - try { + if(messages.isEmpty()) return null; + messages = Collections.unmodifiableList(messages); + b = packetFactory.createBatch(messages); messageStatusLock.writeLock().lock(); try { T txn = db.startTransaction(); try { - db.addOutstandingBatch(txn, c, id, ids); + db.addOutstandingBatch(txn, c, b.getId(), ids); db.commitTransaction(txn); } catch(DbException e) { db.abortTransaction(txn); @@ -540,12 +506,14 @@ DatabaseCleaner.Callback { } finally { contactLock.readLock().unlock(); } + return b; } - public boolean generateBatch(ContactId c, BatchWriter b, - Collection<MessageId> requested) throws DbException, IOException { + public RawBatch generateBatch(ContactId c, int capacity, + Collection<MessageId> requested) throws DbException { Collection<MessageId> ids = new ArrayList<MessageId>(); - Collection<Bytes> messages = new ArrayList<Bytes>(); + List<byte[]> messages = new ArrayList<byte[]>(); + RawBatch b; // Get some sendable messages from the database contactLock.readLock().lock(); try { @@ -558,15 +526,15 @@ DatabaseCleaner.Callback { try { T txn = db.startTransaction(); try { - int capacity = b.getCapacity(); Iterator<MessageId> it = requested.iterator(); while(it.hasNext()) { MessageId m = it.next(); byte[] raw = db.getMessageIfSendable(txn, c, m); if(raw != null) { if(raw.length > capacity) break; + messages.add(raw); ids.add(m); - messages.add(new Bytes(raw)); + capacity -= raw.length; } it.remove(); } @@ -581,21 +549,34 @@ DatabaseCleaner.Callback { } finally { messageStatusLock.readLock().unlock(); } + if(messages.isEmpty()) return null; + messages = Collections.unmodifiableList(messages); + b = packetFactory.createBatch(messages); + messageStatusLock.writeLock().lock(); + try { + T txn = db.startTransaction(); + try { + db.addOutstandingBatch(txn, c, b.getId(), ids); + db.commitTransaction(txn); + } catch(DbException e) { + db.abortTransaction(txn); + throw e; + } + } finally { + messageStatusLock.writeLock().unlock(); + } } finally { messageLock.readLock().unlock(); } } finally { contactLock.readLock().unlock(); } - if(ids.isEmpty()) return false; - writeAndRecordBatch(c, b, ids, messages); - return true; + return b; } - public Collection<MessageId> generateOffer(ContactId c, OfferWriter o) - throws DbException, IOException { - Collection<MessageId> sendable; - List<MessageId> sent = new ArrayList<MessageId>(); + public Offer generateOffer(ContactId c, int maxMessages) + throws DbException { + Collection<MessageId> offered; contactLock.readLock().lock(); try { if(!containsContact(c)) throw new NoSuchContactException(); @@ -605,7 +586,7 @@ DatabaseCleaner.Callback { try { T txn = db.startTransaction(); try { - sendable = db.getSendableMessages(txn, c); + offered = db.getOfferableMessages(txn, c, maxMessages); db.commitTransaction(txn); } catch(DbException e) { db.abortTransaction(txn); @@ -620,33 +601,41 @@ DatabaseCleaner.Callback { } finally { contactLock.readLock().unlock(); } - for(MessageId m : sendable) { - if(!o.writeMessageId(m)) break; - sent.add(m); - } - if(!sent.isEmpty()) o.finish(); - return Collections.unmodifiableList(sent); + return packetFactory.createOffer(offered); } - public void generateSubscriptionUpdate(ContactId c, - SubscriptionUpdateWriter s) throws DbException, IOException { - Map<Group, Long> subs = null; - long timestamp = 0L; + public SubscriptionUpdate generateSubscriptionUpdate(ContactId c) + throws DbException { + boolean due; + Map<Group, Long> subs; + long timestamp; contactLock.readLock().lock(); try { if(!containsContact(c)) throw new NoSuchContactException(); - subscriptionLock.writeLock().lock(); + subscriptionLock.readLock().lock(); try { T txn = db.startTransaction(); try { // Work out whether an update is due long modified = db.getSubscriptionsModified(txn, c); long sent = db.getSubscriptionsSent(txn, c); - if(modified >= sent || updateIsDue(sent)) { - subs = db.getVisibleSubscriptions(txn, c); - timestamp = System.currentTimeMillis(); - db.setSubscriptionsSent(txn, c, timestamp); - } + due = modified >= sent || updateIsDue(sent); + db.commitTransaction(txn); + } catch(DbException e) { + db.abortTransaction(txn); + throw e; + } + } finally { + subscriptionLock.readLock().unlock(); + } + if(!due) return null; + subscriptionLock.writeLock().lock(); + try { + T txn = db.startTransaction(); + try { + subs = db.getVisibleSubscriptions(txn, c); + timestamp = System.currentTimeMillis(); + db.setSubscriptionsSent(txn, c, timestamp); db.commitTransaction(txn); } catch(DbException e) { db.abortTransaction(txn); @@ -658,7 +647,7 @@ DatabaseCleaner.Callback { } finally { contactLock.readLock().unlock(); } - if(subs != null) s.writeSubscriptions(subs, timestamp); + return packetFactory.createSubscriptionUpdate(subs, timestamp); } private boolean updateIsDue(long sent) { @@ -666,25 +655,38 @@ DatabaseCleaner.Callback { return now - sent >= DatabaseConstants.MAX_UPDATE_INTERVAL; } - public void generateTransportUpdate(ContactId c, TransportUpdateWriter t) - throws DbException, IOException { - Collection<Transport> transports = null; - long timestamp = 0L; + public TransportUpdate generateTransportUpdate(ContactId c) + throws DbException { + boolean due; + Collection<Transport> transports; + long timestamp; contactLock.readLock().lock(); try { if(!containsContact(c)) throw new NoSuchContactException(); - transportLock.writeLock().lock(); + transportLock.readLock().lock(); try { T txn = db.startTransaction(); try { // Work out whether an update is due long modified = db.getTransportsModified(txn); long sent = db.getTransportsSent(txn, c); - if(modified >= sent || updateIsDue(sent)) { - transports = db.getLocalTransports(txn); - timestamp = System.currentTimeMillis(); - db.setTransportsSent(txn, c, timestamp); - } + due = modified >= sent || updateIsDue(sent); + db.commitTransaction(txn); + } catch(DbException e) { + db.abortTransaction(txn); + throw e; + } + } finally { + transportLock.readLock().unlock(); + } + if(!due) return null; + transportLock.writeLock().lock(); + try { + T txn = db.startTransaction(); + try { + transports = db.getLocalTransports(txn); + timestamp = System.currentTimeMillis(); + db.setTransportsSent(txn, c, timestamp); db.commitTransaction(txn); } catch(DbException e) { db.abortTransaction(txn); @@ -696,7 +698,7 @@ DatabaseCleaner.Callback { } finally { contactLock.readLock().unlock(); } - if(transports != null) t.writeTransports(transports, timestamp); + return packetFactory.createTransportUpdate(transports, timestamp); } public TransportConfig getConfig(TransportId t) throws DbException { @@ -1119,8 +1121,7 @@ DatabaseCleaner.Callback { return anyStored; } - public void receiveOffer(ContactId c, Offer o, RequestWriter r) - throws DbException, IOException { + public Request receiveOffer(ContactId c, Offer o) throws DbException { Collection<MessageId> offered; BitSet request; contactLock.readLock().lock(); @@ -1161,7 +1162,7 @@ DatabaseCleaner.Callback { } finally { contactLock.readLock().unlock(); } - r.writeRequest(request, offered.size()); + return packetFactory.createRequest(request, offered.size()); } public void receiveSubscriptionUpdate(ContactId c, SubscriptionUpdate s) diff --git a/components/net/sf/briar/db/DatabaseModule.java b/components/net/sf/briar/db/DatabaseModule.java index ab58a5efa595395db09bab9bc44133f9b5554447..dffc491c1a03e0818f2ef1f8cf741d1e7c883b33 100644 --- a/components/net/sf/briar/db/DatabaseModule.java +++ b/components/net/sf/briar/db/DatabaseModule.java @@ -10,6 +10,7 @@ import net.sf.briar.api.db.DatabaseMaxSize; import net.sf.briar.api.db.DatabasePassword; import net.sf.briar.api.lifecycle.ShutdownManager; import net.sf.briar.api.protocol.GroupFactory; +import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.transport.ConnectionContextFactory; import net.sf.briar.api.transport.ConnectionWindowFactory; @@ -36,7 +37,9 @@ public class DatabaseModule extends AbstractModule { @Provides @Singleton DatabaseComponent getDatabaseComponent(Database<Connection> db, - DatabaseCleaner cleaner, ShutdownManager shutdown) { - return new DatabaseComponentImpl<Connection>(db, cleaner, shutdown); + DatabaseCleaner cleaner, ShutdownManager shutdown, + PacketFactory packetFactory) { + return new DatabaseComponentImpl<Connection>(db, cleaner, shutdown, + packetFactory); } } diff --git a/components/net/sf/briar/db/JdbcDatabase.java b/components/net/sf/briar/db/JdbcDatabase.java index 9018ddb7d74f21bbc785e851e9901edb86c5401e..441dd7b3c95ce5619967c2687c46b1a91d9d5895 100644 --- a/components/net/sf/briar/db/JdbcDatabase.java +++ b/components/net/sf/briar/db/JdbcDatabase.java @@ -612,10 +612,11 @@ abstract class JdbcDatabase implements Database<Connection> { else ps.setBytes(4, m.getAuthor().getBytes()); ps.setString(5, m.getSubject()); ps.setLong(6, m.getTimestamp()); - ps.setInt(7, m.getLength()); + byte[] raw = m.getSerialised(); + ps.setInt(7, raw.length); ps.setInt(8, m.getBodyStart()); ps.setInt(9, m.getBodyLength()); - ps.setBytes(10, m.getSerialised()); + ps.setBytes(10, raw); int affected = ps.executeUpdate(); if(affected != 1) throw new DbStateException(); ps.close(); @@ -700,10 +701,11 @@ abstract class JdbcDatabase implements Database<Connection> { else ps.setBytes(2, m.getParent().getBytes()); ps.setString(3, m.getSubject()); ps.setLong(4, m.getTimestamp()); - ps.setInt(5, m.getLength()); + byte[] raw = m.getSerialised(); + ps.setInt(5, raw.length); ps.setInt(6, m.getBodyStart()); ps.setInt(7, m.getBodyLength()); - ps.setBytes(8, m.getSerialised()); + ps.setBytes(8, raw); ps.setInt(9, c.getInt()); int affected = ps.executeUpdate(); if(affected != 1) throw new DbStateException(); @@ -889,15 +891,17 @@ abstract class JdbcDatabase implements Database<Connection> { } } - public Collection<BatchId> getBatchesToAck(Connection txn, ContactId c) - throws DbException { + public Collection<BatchId> getBatchesToAck(Connection txn, ContactId c, + int maxBatches) throws DbException { PreparedStatement ps = null; ResultSet rs = null; try { String sql = "SELECT batchId FROM batchesToAck" - + " WHERE contactId = ?"; + + " WHERE contactId = ?" + + " LIMIT ?"; ps = txn.prepareStatement(sql); ps.setInt(1, c.getInt()); + ps.setInt(2, maxBatches); rs = ps.executeQuery(); List<BatchId> ids = new ArrayList<BatchId>(); while(rs.next()) ids.add(new BatchId(rs.getBytes(1))); @@ -1517,8 +1521,8 @@ abstract class JdbcDatabase implements Database<Connection> { } } - public Collection<MessageId> getSendableMessages(Connection txn, - ContactId c) throws DbException { + public Collection<MessageId> getOfferableMessages(Connection txn, + ContactId c, int maxMessages) throws DbException { PreparedStatement ps = null; ResultSet rs = null; try { @@ -1526,15 +1530,19 @@ abstract class JdbcDatabase implements Database<Connection> { String sql = "SELECT messages.messageId FROM messages" + " JOIN statuses ON messages.messageId = statuses.messageId" + " WHERE messages.contactId = ? AND status = ?" - + " ORDER BY timestamp"; + + " ORDER BY timestamp" + + " LIMIT ?"; ps = txn.prepareStatement(sql); ps.setInt(1, c.getInt()); ps.setShort(2, (short) Status.NEW.ordinal()); + ps.setInt(3, maxMessages); rs = ps.executeQuery(); List<MessageId> ids = new ArrayList<MessageId>(); while(rs.next()) ids.add(new MessageId(rs.getBytes(2))); rs.close(); ps.close(); + if(ids.size() == maxMessages) + return Collections.unmodifiableList(ids); // Do we have any sendable group messages? sql = "SELECT m.messageId FROM messages AS m" + " JOIN contactSubscriptions AS cs" @@ -1547,10 +1555,12 @@ abstract class JdbcDatabase implements Database<Connection> { + " AND timestamp >= start" + " AND status = ?" + " AND sendability > ZERO()" - + " ORDER BY timestamp"; + + " ORDER BY timestamp" + + " LIMIT ?"; ps = txn.prepareStatement(sql); ps.setInt(1, c.getInt()); ps.setShort(2, (short) Status.NEW.ordinal()); + ps.setInt(3, maxMessages - ids.size()); rs = ps.executeQuery(); while(rs.next()) ids.add(new MessageId(rs.getBytes(2))); rs.close(); diff --git a/components/net/sf/briar/protocol/AckFactory.java b/components/net/sf/briar/protocol/AckFactory.java deleted file mode 100644 index 9c574db5914b6cc988e1ff8909f17efef8c87f31..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/AckFactory.java +++ /dev/null @@ -1,11 +0,0 @@ -package net.sf.briar.protocol; - -import java.util.Collection; - -import net.sf.briar.api.protocol.Ack; -import net.sf.briar.api.protocol.BatchId; - -interface AckFactory { - - Ack createAck(Collection<BatchId> acked); -} diff --git a/components/net/sf/briar/protocol/AckFactoryImpl.java b/components/net/sf/briar/protocol/AckFactoryImpl.java deleted file mode 100644 index f08715c8f50c734784aa74e4dbe1815f7ddc4067..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/AckFactoryImpl.java +++ /dev/null @@ -1,13 +0,0 @@ -package net.sf.briar.protocol; - -import java.util.Collection; - -import net.sf.briar.api.protocol.Ack; -import net.sf.briar.api.protocol.BatchId; - -class AckFactoryImpl implements AckFactory { - - public Ack createAck(Collection<BatchId> acked) { - return new AckImpl(acked); - } -} diff --git a/components/net/sf/briar/protocol/AckReader.java b/components/net/sf/briar/protocol/AckReader.java index ebfa0ab9b343ff5a66d1136f94ef5c7bc793d03d..51bc713fe7dc3bd4b234acd4d55be6ffa9c12063 100644 --- a/components/net/sf/briar/protocol/AckReader.java +++ b/components/net/sf/briar/protocol/AckReader.java @@ -6,6 +6,7 @@ import java.util.Collection; import net.sf.briar.api.FormatException; import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.BatchId; +import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.UniqueId; @@ -16,11 +17,11 @@ import net.sf.briar.api.serial.Reader; class AckReader implements ObjectReader<Ack> { - private final AckFactory ackFactory; + private final PacketFactory packetFactory; private final ObjectReader<BatchId> batchIdReader; - AckReader(AckFactory ackFactory) { - this.ackFactory = ackFactory; + AckReader(PacketFactory packetFactory) { + this.packetFactory = packetFactory; batchIdReader = new BatchIdReader(); } @@ -36,7 +37,7 @@ class AckReader implements ObjectReader<Ack> { r.removeObjectReader(Types.BATCH_ID); r.removeConsumer(counting); // Build and return the ack - return ackFactory.createAck(batches); + return packetFactory.createAck(batches); } private static class BatchIdReader implements ObjectReader<BatchId> { diff --git a/components/net/sf/briar/protocol/MessageImpl.java b/components/net/sf/briar/protocol/MessageImpl.java index eece7bde294b471354b607222732eef661173e95..ab770ad6203c2804b37197ca7311ed9a7bb2e22d 100644 --- a/components/net/sf/briar/protocol/MessageImpl.java +++ b/components/net/sf/briar/protocol/MessageImpl.java @@ -59,10 +59,6 @@ class MessageImpl implements Message { return timestamp; } - public int getLength() { - return raw.length; - } - public byte[] getSerialised() { return raw; } diff --git a/components/net/sf/briar/protocol/OfferFactory.java b/components/net/sf/briar/protocol/OfferFactory.java deleted file mode 100644 index 6f19d4e29192edd0eadd298b67d398f050391337..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/OfferFactory.java +++ /dev/null @@ -1,11 +0,0 @@ -package net.sf.briar.protocol; - -import java.util.Collection; - -import net.sf.briar.api.protocol.MessageId; -import net.sf.briar.api.protocol.Offer; - -interface OfferFactory { - - Offer createOffer(Collection<MessageId> offered); -} diff --git a/components/net/sf/briar/protocol/OfferFactoryImpl.java b/components/net/sf/briar/protocol/OfferFactoryImpl.java deleted file mode 100644 index 075527d14768faa86d022cfe308f8f7eda203b4b..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/OfferFactoryImpl.java +++ /dev/null @@ -1,13 +0,0 @@ -package net.sf.briar.protocol; - -import java.util.Collection; - -import net.sf.briar.api.protocol.MessageId; -import net.sf.briar.api.protocol.Offer; - -class OfferFactoryImpl implements OfferFactory { - - public Offer createOffer(Collection<MessageId> offered) { - return new OfferImpl(offered); - } -} diff --git a/components/net/sf/briar/protocol/OfferReader.java b/components/net/sf/briar/protocol/OfferReader.java index e00c8207be2c1b36710ed97d60df78d226da7952..a476886df714b04a374f59d639cfcffb2acfd017 100644 --- a/components/net/sf/briar/protocol/OfferReader.java +++ b/components/net/sf/briar/protocol/OfferReader.java @@ -5,6 +5,7 @@ import java.util.Collection; import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.Offer; +import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Types; import net.sf.briar.api.serial.Consumer; @@ -15,12 +16,12 @@ import net.sf.briar.api.serial.Reader; class OfferReader implements ObjectReader<Offer> { private final ObjectReader<MessageId> messageIdReader; - private final OfferFactory offerFactory; + private final PacketFactory packetFactory; OfferReader(ObjectReader<MessageId> messageIdReader, - OfferFactory offerFactory) { + PacketFactory packetFactory) { this.messageIdReader = messageIdReader; - this.offerFactory = offerFactory; + this.packetFactory = packetFactory; } public Offer readObject(Reader r) throws IOException { @@ -35,6 +36,6 @@ class OfferReader implements ObjectReader<Offer> { r.removeObjectReader(Types.MESSAGE_ID); r.removeConsumer(counting); // Build and return the offer - return offerFactory.createOffer(messages); + return packetFactory.createOffer(messages); } } diff --git a/components/net/sf/briar/protocol/PacketFactoryImpl.java b/components/net/sf/briar/protocol/PacketFactoryImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..e6239a69fd75420df693d6152bb16c74d5a195ac --- /dev/null +++ b/components/net/sf/briar/protocol/PacketFactoryImpl.java @@ -0,0 +1,59 @@ +package net.sf.briar.protocol; + +import java.util.BitSet; +import java.util.Collection; +import java.util.Map; + +import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.api.crypto.MessageDigest; +import net.sf.briar.api.protocol.Ack; +import net.sf.briar.api.protocol.BatchId; +import net.sf.briar.api.protocol.Group; +import net.sf.briar.api.protocol.MessageId; +import net.sf.briar.api.protocol.Offer; +import net.sf.briar.api.protocol.PacketFactory; +import net.sf.briar.api.protocol.RawBatch; +import net.sf.briar.api.protocol.Request; +import net.sf.briar.api.protocol.SubscriptionUpdate; +import net.sf.briar.api.protocol.Transport; +import net.sf.briar.api.protocol.TransportUpdate; + +import com.google.inject.Inject; + +class PacketFactoryImpl implements PacketFactory { + + private final CryptoComponent crypto; + + @Inject + PacketFactoryImpl(CryptoComponent crypto) { + this.crypto = crypto; + } + + public Ack createAck(Collection<BatchId> acked) { + return new AckImpl(acked); + } + + public RawBatch createBatch(Collection<byte[]> messages) { + MessageDigest messageDigest = crypto.getMessageDigest(); + for(byte[] raw : messages) messageDigest.update(raw); + return new RawBatchImpl(new BatchId(messageDigest.digest()), messages); + } + + public Offer createOffer(Collection<MessageId> offered) { + return new OfferImpl(offered); + } + + public Request createRequest(BitSet requested, int length) { + return new RequestImpl(requested, length); + } + + public SubscriptionUpdate createSubscriptionUpdate(Map<Group, Long> subs, + long timestamp) { + return new SubscriptionUpdateImpl(subs, timestamp); + } + + public TransportUpdate createTransportUpdate( + Collection<Transport> transports, long timestamp) { + return new TransportUpdateImpl(transports, timestamp); + } +} diff --git a/components/net/sf/briar/protocol/ProtocolModule.java b/components/net/sf/briar/protocol/ProtocolModule.java index 77ab27ed09134f8d1b2b45c309ea6be1212d69c0..9527f4d86dcfe23a2a23b9678e772d677e34c78e 100644 --- a/components/net/sf/briar/protocol/ProtocolModule.java +++ b/components/net/sf/briar/protocol/ProtocolModule.java @@ -9,7 +9,9 @@ import net.sf.briar.api.protocol.GroupFactory; import net.sf.briar.api.protocol.MessageFactory; import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.Offer; +import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory; +import net.sf.briar.api.protocol.ProtocolWriterFactory; import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.TransportUpdate; @@ -23,21 +25,17 @@ public class ProtocolModule extends AbstractModule { @Override protected void configure() { - bind(AckFactory.class).to(AckFactoryImpl.class); bind(AuthorFactory.class).to(AuthorFactoryImpl.class); bind(GroupFactory.class).to(GroupFactoryImpl.class); bind(MessageFactory.class).to(MessageFactoryImpl.class); - bind(OfferFactory.class).to(OfferFactoryImpl.class); + bind(PacketFactory.class).to(PacketFactoryImpl.class); bind(ProtocolReaderFactory.class).to(ProtocolReaderFactoryImpl.class); - bind(RequestFactory.class).to(RequestFactoryImpl.class); - bind(SubscriptionUpdateFactory.class).to( - SubscriptionUpdateFactoryImpl.class); - bind(TransportUpdateFactory.class).to(TransportUpdateFactoryImpl.class); + bind(ProtocolWriterFactory.class).to(ProtocolWriterFactoryImpl.class); bind(UnverifiedBatchFactory.class).to(UnverifiedBatchFactoryImpl.class); } @Provides - ObjectReader<Ack> getAckReader(AckFactory ackFactory) { + ObjectReader<Ack> getAckReader(PacketFactory ackFactory) { return new AckReader(ackFactory); } @@ -75,25 +73,24 @@ public class ProtocolModule extends AbstractModule { @Provides ObjectReader<Offer> getOfferReader(ObjectReader<MessageId> messageIdReader, - OfferFactory offerFactory) { - return new OfferReader(messageIdReader, offerFactory); + PacketFactory packetFactory) { + return new OfferReader(messageIdReader, packetFactory); } @Provides - ObjectReader<Request> getRequestReader(RequestFactory requestFactory) { - return new RequestReader(requestFactory); + ObjectReader<Request> getRequestReader(PacketFactory packetFactory) { + return new RequestReader(packetFactory); } @Provides ObjectReader<SubscriptionUpdate> getSubscriptionReader( - ObjectReader<Group> groupReader, - SubscriptionUpdateFactory subscriptionFactory) { - return new SubscriptionUpdateReader(groupReader, subscriptionFactory); + ObjectReader<Group> groupReader, PacketFactory packetFactory) { + return new SubscriptionUpdateReader(groupReader, packetFactory); } @Provides ObjectReader<TransportUpdate> getTransportReader( - TransportUpdateFactory transportFactory) { - return new TransportUpdateReader(transportFactory); + PacketFactory packetFactory) { + return new TransportUpdateReader(packetFactory); } } diff --git a/components/net/sf/briar/protocol/ProtocolWriterFactoryImpl.java b/components/net/sf/briar/protocol/ProtocolWriterFactoryImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..1327769f3a53654e81dd1a55afb3856f4b24bbd2 --- /dev/null +++ b/components/net/sf/briar/protocol/ProtocolWriterFactoryImpl.java @@ -0,0 +1,27 @@ +package net.sf.briar.protocol; + +import java.io.OutputStream; + +import net.sf.briar.api.protocol.ProtocolWriter; +import net.sf.briar.api.protocol.ProtocolWriterFactory; +import net.sf.briar.api.serial.SerialComponent; +import net.sf.briar.api.serial.WriterFactory; + +import com.google.inject.Inject; + +class ProtocolWriterFactoryImpl implements ProtocolWriterFactory { + + private final SerialComponent serial; + private final WriterFactory writerFactory; + + @Inject + ProtocolWriterFactoryImpl(SerialComponent serial, + WriterFactory writerFactory) { + this.serial = serial; + this.writerFactory = writerFactory; + } + + public ProtocolWriter createProtocolWriter(OutputStream out) { + return new ProtocolWriterImpl(serial, writerFactory, out); + } +} diff --git a/components/net/sf/briar/protocol/ProtocolWriterImpl.java b/components/net/sf/briar/protocol/ProtocolWriterImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..9cff78109deb077482e604371ad6f545c8cba6ea --- /dev/null +++ b/components/net/sf/briar/protocol/ProtocolWriterImpl.java @@ -0,0 +1,143 @@ +package net.sf.briar.protocol; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.BitSet; +import java.util.Map.Entry; + +import net.sf.briar.api.protocol.Ack; +import net.sf.briar.api.protocol.BatchId; +import net.sf.briar.api.protocol.Group; +import net.sf.briar.api.protocol.MessageId; +import net.sf.briar.api.protocol.Offer; +import static net.sf.briar.api.protocol.ProtocolConstants.MAX_PACKET_LENGTH; +import net.sf.briar.api.protocol.ProtocolWriter; +import net.sf.briar.api.protocol.RawBatch; +import net.sf.briar.api.protocol.Request; +import net.sf.briar.api.protocol.SubscriptionUpdate; +import net.sf.briar.api.protocol.Transport; +import net.sf.briar.api.protocol.TransportUpdate; +import net.sf.briar.api.protocol.Types; +import net.sf.briar.api.serial.SerialComponent; +import net.sf.briar.api.serial.Writer; +import net.sf.briar.api.serial.WriterFactory; + +// This class is not thread-safe +class ProtocolWriterImpl implements ProtocolWriter { + + private final SerialComponent serial; + private final OutputStream out; + private final Writer w; + + ProtocolWriterImpl(SerialComponent serial, WriterFactory writerFactory, + OutputStream out) { + this.serial = serial; + this.out = out; + w = writerFactory.createWriter(out); + } + + public int getMaxBatchesForAck(long capacity) { + int packet = (int) Math.min(capacity, MAX_PACKET_LENGTH); + int overhead = serial.getSerialisedStructIdLength(Types.ACK) + + serial.getSerialisedListStartLength() + + serial.getSerialisedListEndLength(); + int idLength = serial.getSerialisedUniqueIdLength(Types.BATCH_ID); + return (packet - overhead) / idLength; + } + + public int getMaxMessagesForOffer(long capacity) { + int packet = (int) Math.min(capacity, MAX_PACKET_LENGTH); + int overhead = serial.getSerialisedStructIdLength(Types.OFFER) + + serial.getSerialisedListStartLength() + + serial.getSerialisedListEndLength(); + int idLength = serial.getSerialisedUniqueIdLength(Types.MESSAGE_ID); + return (packet - overhead) / idLength; + } + + public int getMessageCapacityForBatch(long capacity) { + int packet = (int) Math.min(capacity, MAX_PACKET_LENGTH); + int overhead = serial.getSerialisedStructIdLength(Types.BATCH) + + serial.getSerialisedListStartLength() + + serial.getSerialisedListEndLength(); + return packet - overhead; + } + + public void writeAck(Ack a) throws IOException { + w.writeStructId(Types.ACK); + w.writeListStart(); + for(BatchId b : a.getBatchIds()) { + w.writeStructId(Types.BATCH_ID); + w.writeBytes(b.getBytes()); + } + w.writeListEnd(); + } + + public void writeBatch(RawBatch b) throws IOException { + w.writeStructId(Types.BATCH); + w.writeListStart(); + for(byte[] raw : b.getMessages()) out.write(raw); + w.writeListEnd(); + } + + public void writeOffer(Offer o) throws IOException { + w.writeStructId(Types.OFFER); + w.writeListStart(); + for(MessageId m : o.getMessageIds()) { + w.writeStructId(Types.MESSAGE_ID); + w.writeBytes(m.getBytes()); + } + w.writeListEnd(); + } + + public void writeRequest(Request r) throws IOException { + BitSet b = r.getBitmap(); + int length = r.getLength(); + // If the number of bits isn't a multiple of 8, round up to a byte + int bytes = length % 8 == 0 ? length / 8 : length / 8 + 1; + byte[] bitmap = new byte[bytes]; + // I'm kind of surprised BitSet doesn't have a method for this + for(int i = 0; i < length; i++) { + if(b.get(i)) { + int offset = i / 8; + byte bit = (byte) (128 >> i % 8); + bitmap[offset] |= bit; + } + } + w.writeStructId(Types.REQUEST); + w.writeUint7((byte) (bytes * 8 - length)); + w.writeBytes(bitmap); + } + + public void writeSubscriptionUpdate(SubscriptionUpdate s) + throws IOException { + w.writeStructId(Types.SUBSCRIPTION_UPDATE); + w.writeMapStart(); + for(Entry<Group, Long> e : s.getSubscriptions().entrySet()) { + writeGroup(w, e.getKey()); + w.writeInt64(e.getValue()); + } + w.writeMapEnd(); + w.writeInt64(s.getTimestamp()); + } + + private void writeGroup(Writer w, Group g) throws IOException { + w.writeStructId(Types.GROUP); + w.writeString(g.getName()); + byte[] publicKey = g.getPublicKey(); + if(publicKey == null) w.writeNull(); + else w.writeBytes(publicKey); + } + + public void writeTransportUpdate(TransportUpdate t) throws IOException { + w.writeStructId(Types.TRANSPORT_UPDATE); + w.writeListStart(); + for(Transport p : t.getTransports()) { + w.writeStructId(Types.TRANSPORT); + w.writeBytes(p.getId().getBytes()); + w.writeInt32(p.getIndex().getInt()); + w.writeMap(p); + } + w.writeListEnd(); + w.writeInt64(t.getTimestamp()); + } +} diff --git a/components/net/sf/briar/protocol/RawBatchImpl.java b/components/net/sf/briar/protocol/RawBatchImpl.java new file mode 100644 index 0000000000000000000000000000000000000000..06da27023c94903151bf1718b871643a03105712 --- /dev/null +++ b/components/net/sf/briar/protocol/RawBatchImpl.java @@ -0,0 +1,25 @@ +package net.sf.briar.protocol; + +import java.util.Collection; + +import net.sf.briar.api.protocol.BatchId; +import net.sf.briar.api.protocol.RawBatch; + +class RawBatchImpl implements RawBatch { + + private final BatchId id; + private final Collection<byte[]> messages; + + RawBatchImpl(BatchId id, Collection<byte[]> messages) { + this.id = id; + this.messages = messages; + } + + public BatchId getId() { + return id; + } + + public Collection<byte[]> getMessages() { + return messages; + } +} diff --git a/components/net/sf/briar/protocol/RequestFactory.java b/components/net/sf/briar/protocol/RequestFactory.java deleted file mode 100644 index 005982b826a5c32ae72efceb1fe563d24d6406ec..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/RequestFactory.java +++ /dev/null @@ -1,10 +0,0 @@ -package net.sf.briar.protocol; - -import java.util.BitSet; - -import net.sf.briar.api.protocol.Request; - -interface RequestFactory { - - Request createRequest(BitSet requested); -} diff --git a/components/net/sf/briar/protocol/RequestFactoryImpl.java b/components/net/sf/briar/protocol/RequestFactoryImpl.java deleted file mode 100644 index 0c2c77cb1cd324c4f292adeba8da1125dfbd21f2..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/RequestFactoryImpl.java +++ /dev/null @@ -1,12 +0,0 @@ -package net.sf.briar.protocol; - -import java.util.BitSet; - -import net.sf.briar.api.protocol.Request; - -class RequestFactoryImpl implements RequestFactory { - - public Request createRequest(BitSet requested) { - return new RequestImpl(requested); - } -} diff --git a/components/net/sf/briar/protocol/RequestImpl.java b/components/net/sf/briar/protocol/RequestImpl.java index ddb31898aeb2f4f02ae581ce889eaebf22894d09..aeae2cd285571f30c456276b61f76832145479a8 100644 --- a/components/net/sf/briar/protocol/RequestImpl.java +++ b/components/net/sf/briar/protocol/RequestImpl.java @@ -7,12 +7,18 @@ import net.sf.briar.api.protocol.Request; class RequestImpl implements Request { private final BitSet requested; + private final int length; - RequestImpl(BitSet requested) { + RequestImpl(BitSet requested, int length) { this.requested = requested; + this.length = length; } public BitSet getBitmap() { return requested; } + + public int getLength() { + return length; + } } diff --git a/components/net/sf/briar/protocol/RequestReader.java b/components/net/sf/briar/protocol/RequestReader.java index 13ac7768c55d49ce3e33e179d104c81b3c3c9416..cab6d20d7a9f4c18926c83cee48db7609ed40264 100644 --- a/components/net/sf/briar/protocol/RequestReader.java +++ b/components/net/sf/briar/protocol/RequestReader.java @@ -3,6 +3,8 @@ package net.sf.briar.protocol; import java.io.IOException; import java.util.BitSet; +import net.sf.briar.api.FormatException; +import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.Types; @@ -13,10 +15,10 @@ import net.sf.briar.api.serial.Reader; class RequestReader implements ObjectReader<Request> { - private final RequestFactory requestFactory; + private final PacketFactory packetFactory; - RequestReader(RequestFactory requestFactory) { - this.requestFactory = requestFactory; + RequestReader(PacketFactory packetFactory) { + this.packetFactory = packetFactory; } public Request readObject(Reader r) throws IOException { @@ -26,16 +28,19 @@ class RequestReader implements ObjectReader<Request> { // Read the data r.addConsumer(counting); r.readStructId(Types.REQUEST); + int padding = r.readUint7(); + if(padding > 7) throw new FormatException(); byte[] bitmap = r.readBytes(ProtocolConstants.MAX_PACKET_LENGTH); r.removeConsumer(counting); // Convert the bitmap into a BitSet - BitSet b = new BitSet(bitmap.length * 8); + int length = bitmap.length * 8 - padding; + BitSet b = new BitSet(length); for(int i = 0; i < bitmap.length; i++) { - for(int j = 0; j < 8; j++) { + for(int j = 0; j < 8 && i * 8 + j < length; j++) { byte bit = (byte) (128 >> j); if((bitmap[i] & bit) != 0) b.set(i * 8 + j); } } - return requestFactory.createRequest(b); + return packetFactory.createRequest(b, length); } } diff --git a/components/net/sf/briar/protocol/SubscriptionUpdateFactory.java b/components/net/sf/briar/protocol/SubscriptionUpdateFactory.java deleted file mode 100644 index 5d04150ececd0f35a58bc21a2c71c781ff2af32c..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/SubscriptionUpdateFactory.java +++ /dev/null @@ -1,12 +0,0 @@ -package net.sf.briar.protocol; - -import java.util.Map; - -import net.sf.briar.api.protocol.Group; -import net.sf.briar.api.protocol.SubscriptionUpdate; - -interface SubscriptionUpdateFactory { - - SubscriptionUpdate createSubscriptions(Map<Group, Long> subs, - long timestamp); -} diff --git a/components/net/sf/briar/protocol/SubscriptionUpdateFactoryImpl.java b/components/net/sf/briar/protocol/SubscriptionUpdateFactoryImpl.java deleted file mode 100644 index aeb93b3a768f010ac0682bb4b7b86272df88c7e8..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/SubscriptionUpdateFactoryImpl.java +++ /dev/null @@ -1,14 +0,0 @@ -package net.sf.briar.protocol; - -import java.util.Map; - -import net.sf.briar.api.protocol.Group; -import net.sf.briar.api.protocol.SubscriptionUpdate; - -class SubscriptionUpdateFactoryImpl implements SubscriptionUpdateFactory { - - public SubscriptionUpdate createSubscriptions(Map<Group, Long> subs, - long timestamp) { - return new SubscriptionUpdateImpl(subs, timestamp); - } -} diff --git a/components/net/sf/briar/protocol/SubscriptionUpdateReader.java b/components/net/sf/briar/protocol/SubscriptionUpdateReader.java index fa02250ce020937c50f92d2b511a06c722b06162..74b55089887cbb9ed91c5ff942c88e0234d62b76 100644 --- a/components/net/sf/briar/protocol/SubscriptionUpdateReader.java +++ b/components/net/sf/briar/protocol/SubscriptionUpdateReader.java @@ -5,6 +5,7 @@ import java.util.Map; import net.sf.briar.api.FormatException; import net.sf.briar.api.protocol.Group; +import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.Types; @@ -16,12 +17,12 @@ import net.sf.briar.api.serial.Reader; class SubscriptionUpdateReader implements ObjectReader<SubscriptionUpdate> { private final ObjectReader<Group> groupReader; - private final SubscriptionUpdateFactory subscriptionFactory; + private final PacketFactory packetFactory; SubscriptionUpdateReader(ObjectReader<Group> groupReader, - SubscriptionUpdateFactory subscriptionFactory) { + PacketFactory packetFactory) { this.groupReader = groupReader; - this.subscriptionFactory = subscriptionFactory; + this.packetFactory = packetFactory; } public SubscriptionUpdate readObject(Reader r) throws IOException { @@ -38,6 +39,6 @@ class SubscriptionUpdateReader implements ObjectReader<SubscriptionUpdate> { if(timestamp < 0L) throw new FormatException(); r.removeConsumer(counting); // Build and return the subscription update - return subscriptionFactory.createSubscriptions(subs, timestamp); + return packetFactory.createSubscriptionUpdate(subs, timestamp); } } diff --git a/components/net/sf/briar/protocol/TransportUpdateFactory.java b/components/net/sf/briar/protocol/TransportUpdateFactory.java deleted file mode 100644 index 0be7251d21d8b88a697677c5447e562da6102be9..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/TransportUpdateFactory.java +++ /dev/null @@ -1,12 +0,0 @@ -package net.sf.briar.protocol; - -import java.util.Collection; - -import net.sf.briar.api.protocol.Transport; -import net.sf.briar.api.protocol.TransportUpdate; - -interface TransportUpdateFactory { - - TransportUpdate createTransportUpdate(Collection<Transport> transports, - long timestamp); -} diff --git a/components/net/sf/briar/protocol/TransportUpdateFactoryImpl.java b/components/net/sf/briar/protocol/TransportUpdateFactoryImpl.java deleted file mode 100644 index bf098bcd0bc5a6fb310d7d51f069d385360ccf9c..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/TransportUpdateFactoryImpl.java +++ /dev/null @@ -1,14 +0,0 @@ -package net.sf.briar.protocol; - -import java.util.Collection; - -import net.sf.briar.api.protocol.Transport; -import net.sf.briar.api.protocol.TransportUpdate; - -class TransportUpdateFactoryImpl implements TransportUpdateFactory { - - public TransportUpdate createTransportUpdate( - Collection<Transport> transports, long timestamp) { - return new TransportUpdateImpl(transports, timestamp); - } -} diff --git a/components/net/sf/briar/protocol/TransportUpdateReader.java b/components/net/sf/briar/protocol/TransportUpdateReader.java index 91b3265ac5918cd1d5562318da2d0d6d8b50fc7c..7b52cec7da57bd9a1a6eaf026e91758d0da2d525 100644 --- a/components/net/sf/briar/protocol/TransportUpdateReader.java +++ b/components/net/sf/briar/protocol/TransportUpdateReader.java @@ -7,6 +7,7 @@ import java.util.Map; import java.util.Set; import net.sf.briar.api.FormatException; +import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.TransportId; @@ -21,11 +22,11 @@ import net.sf.briar.api.serial.Reader; class TransportUpdateReader implements ObjectReader<TransportUpdate> { - private final TransportUpdateFactory transportUpdateFactory; + private final PacketFactory packetFactory; private final ObjectReader<Transport> transportReader; - TransportUpdateReader(TransportUpdateFactory transportFactory) { - this.transportUpdateFactory = transportFactory; + TransportUpdateReader(PacketFactory packetFactory) { + this.packetFactory = packetFactory; transportReader = new TransportReader(); } @@ -51,8 +52,7 @@ class TransportUpdateReader implements ObjectReader<TransportUpdate> { if(!indices.add(t.getIndex())) throw new FormatException(); } // Build and return the transport update - return transportUpdateFactory.createTransportUpdate(transports, - timestamp); + return packetFactory.createTransportUpdate(transports, timestamp); } private static class TransportReader implements ObjectReader<Transport> { diff --git a/components/net/sf/briar/protocol/writers/AckWriterImpl.java b/components/net/sf/briar/protocol/writers/AckWriterImpl.java deleted file mode 100644 index 5b107d1c29f249577e4aefaf2bbaa2416ff216cc..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/writers/AckWriterImpl.java +++ /dev/null @@ -1,64 +0,0 @@ -package net.sf.briar.protocol.writers; - -import java.io.IOException; -import java.io.OutputStream; - -import net.sf.briar.api.protocol.BatchId; -import net.sf.briar.api.protocol.ProtocolConstants; -import net.sf.briar.api.protocol.Types; -import net.sf.briar.api.protocol.writers.AckWriter; -import net.sf.briar.api.serial.SerialComponent; -import net.sf.briar.api.serial.Writer; -import net.sf.briar.api.serial.WriterFactory; - -class AckWriterImpl implements AckWriter { - - private final OutputStream out; - private final int headerLength, idLength, footerLength; - private final Writer w; - - private boolean started = false; - private int capacity = ProtocolConstants.MAX_PACKET_LENGTH; - - AckWriterImpl(OutputStream out, SerialComponent serial, - WriterFactory writerFactory) { - this.out = out; - headerLength = serial.getSerialisedStructIdLength(Types.ACK) - + serial.getSerialisedListStartLength(); - idLength = serial.getSerialisedUniqueIdLength(Types.BATCH_ID); - footerLength = serial.getSerialisedListEndLength(); - w = writerFactory.createWriter(out); - } - - public void setMaxPacketLength(int length) { - if(started) throw new IllegalStateException(); - if(length < 0 || length > ProtocolConstants.MAX_PACKET_LENGTH) - throw new IllegalArgumentException(); - capacity = length; - } - - public boolean writeBatchId(BatchId b) throws IOException { - int overhead = started ? footerLength : headerLength + footerLength; - if(capacity < idLength + overhead) return false; - if(!started) start(); - w.writeStructId(Types.BATCH_ID); - w.writeBytes(b.getBytes()); - capacity -= idLength; - return true; - } - - public void finish() throws IOException { - if(!started) start(); - w.writeListEnd(); - out.flush(); - capacity = ProtocolConstants.MAX_PACKET_LENGTH; - started = false; - } - - private void start() throws IOException { - w.writeStructId(Types.ACK); - w.writeListStart(); - capacity -= headerLength; - started = true; - } -} diff --git a/components/net/sf/briar/protocol/writers/BatchWriterImpl.java b/components/net/sf/briar/protocol/writers/BatchWriterImpl.java deleted file mode 100644 index dea10f2622c15f9312a0d9d878a2256f5918e96e..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/writers/BatchWriterImpl.java +++ /dev/null @@ -1,78 +0,0 @@ -package net.sf.briar.protocol.writers; - -import java.io.IOException; -import java.io.OutputStream; - -import net.sf.briar.api.crypto.MessageDigest; -import net.sf.briar.api.protocol.BatchId; -import net.sf.briar.api.protocol.ProtocolConstants; -import net.sf.briar.api.protocol.Types; -import net.sf.briar.api.protocol.writers.BatchWriter; -import net.sf.briar.api.serial.DigestingConsumer; -import net.sf.briar.api.serial.SerialComponent; -import net.sf.briar.api.serial.Writer; -import net.sf.briar.api.serial.WriterFactory; - -class BatchWriterImpl implements BatchWriter { - - private final OutputStream out; - private final int headerLength, footerLength; - private final Writer w; - private final MessageDigest messageDigest; - private final DigestingConsumer digestingConsumer; - - private boolean started = false; - private int capacity = ProtocolConstants.MAX_PACKET_LENGTH; - private int remaining = capacity; - - BatchWriterImpl(OutputStream out, SerialComponent serial, - WriterFactory writerFactory, MessageDigest messageDigest) { - this.out = out; - headerLength = serial.getSerialisedStructIdLength(Types.BATCH) - + serial.getSerialisedListStartLength(); - footerLength = serial.getSerialisedListEndLength(); - w = writerFactory.createWriter(this.out); - this.messageDigest = messageDigest; - digestingConsumer = new DigestingConsumer(messageDigest); - } - - public int getCapacity() { - return capacity - headerLength - footerLength; - } - - public void setMaxPacketLength(int length) { - if(started) throw new IllegalStateException(); - if(length < 0 || length > ProtocolConstants.MAX_PACKET_LENGTH) - throw new IllegalArgumentException(); - remaining = capacity = length; - } - - public boolean writeMessage(byte[] message) throws IOException { - int overhead = started ? footerLength : headerLength + footerLength; - if(remaining < message.length + overhead) return false; - if(!started) start(); - // Bypass the writer and write the raw message directly - out.write(message); - remaining -= message.length; - return true; - } - - public BatchId finish() throws IOException { - if(!started) start(); - w.writeListEnd(); - w.removeConsumer(digestingConsumer); - out.flush(); - remaining = capacity = ProtocolConstants.MAX_PACKET_LENGTH; - started = false; - return new BatchId(messageDigest.digest()); - } - - private void start() throws IOException { - messageDigest.reset(); - w.addConsumer(digestingConsumer); - w.writeStructId(Types.BATCH); - w.writeListStart(); - remaining -= headerLength; - started = true; - } -} diff --git a/components/net/sf/briar/protocol/writers/OfferWriterImpl.java b/components/net/sf/briar/protocol/writers/OfferWriterImpl.java deleted file mode 100644 index ba592f36de9f49780e8811ead3cd5db27df2e843..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/writers/OfferWriterImpl.java +++ /dev/null @@ -1,64 +0,0 @@ -package net.sf.briar.protocol.writers; - -import java.io.IOException; -import java.io.OutputStream; - -import net.sf.briar.api.protocol.MessageId; -import net.sf.briar.api.protocol.ProtocolConstants; -import net.sf.briar.api.protocol.Types; -import net.sf.briar.api.protocol.writers.OfferWriter; -import net.sf.briar.api.serial.SerialComponent; -import net.sf.briar.api.serial.Writer; -import net.sf.briar.api.serial.WriterFactory; - -class OfferWriterImpl implements OfferWriter { - - private final OutputStream out; - private final int headerLength, idLength, footerLength; - private final Writer w; - - private boolean started = false; - private int capacity = ProtocolConstants.MAX_PACKET_LENGTH; - - OfferWriterImpl(OutputStream out, SerialComponent serial, - WriterFactory writerFactory) { - this.out = out; - headerLength = serial.getSerialisedStructIdLength(Types.OFFER) - + serial.getSerialisedListStartLength(); - idLength = serial.getSerialisedUniqueIdLength(Types.MESSAGE_ID); - footerLength = serial.getSerialisedListEndLength(); - w = writerFactory.createWriter(out); - } - - public void setMaxPacketLength(int length) { - if(started) throw new IllegalStateException(); - if(length < 0 || length > ProtocolConstants.MAX_PACKET_LENGTH) - throw new IllegalArgumentException(); - capacity = length; - } - - public boolean writeMessageId(MessageId m) throws IOException { - int overhead = started ? footerLength : headerLength + footerLength; - if(capacity < idLength + overhead) return false; - if(!started) start(); - w.writeStructId(Types.MESSAGE_ID); - w.writeBytes(m.getBytes()); - capacity -= idLength; - return true; - } - - public void finish() throws IOException { - if(!started) start(); - w.writeListEnd(); - out.flush(); - capacity = ProtocolConstants.MAX_PACKET_LENGTH; - started = false; - } - - private void start() throws IOException { - w.writeStructId(Types.OFFER); - w.writeListStart(); - capacity -= headerLength; - started = true; - } -} diff --git a/components/net/sf/briar/protocol/writers/ProtocolWriterFactoryImpl.java b/components/net/sf/briar/protocol/writers/ProtocolWriterFactoryImpl.java deleted file mode 100644 index e977d2bd01b16b889949b3315434dd9dc7c926aa..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/writers/ProtocolWriterFactoryImpl.java +++ /dev/null @@ -1,57 +0,0 @@ -package net.sf.briar.protocol.writers; - -import java.io.OutputStream; - -import net.sf.briar.api.crypto.CryptoComponent; -import net.sf.briar.api.crypto.MessageDigest; -import net.sf.briar.api.protocol.writers.AckWriter; -import net.sf.briar.api.protocol.writers.BatchWriter; -import net.sf.briar.api.protocol.writers.OfferWriter; -import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; -import net.sf.briar.api.protocol.writers.RequestWriter; -import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter; -import net.sf.briar.api.protocol.writers.TransportUpdateWriter; -import net.sf.briar.api.serial.SerialComponent; -import net.sf.briar.api.serial.WriterFactory; - -import com.google.inject.Inject; - -class ProtocolWriterFactoryImpl implements ProtocolWriterFactory { - - private final MessageDigest messageDigest; - private final SerialComponent serial; - private final WriterFactory writerFactory; - - @Inject - ProtocolWriterFactoryImpl(CryptoComponent crypto, - SerialComponent serial, WriterFactory writerFactory) { - messageDigest = crypto.getMessageDigest(); - this.serial = serial; - this.writerFactory = writerFactory; - } - - public AckWriter createAckWriter(OutputStream out) { - return new AckWriterImpl(out, serial, writerFactory); - } - - public BatchWriter createBatchWriter(OutputStream out) { - return new BatchWriterImpl(out, serial, writerFactory, messageDigest); - } - - public OfferWriter createOfferWriter(OutputStream out) { - return new OfferWriterImpl(out, serial, writerFactory); - } - - public RequestWriter createRequestWriter(OutputStream out) { - return new RequestWriterImpl(out, writerFactory); - } - - public SubscriptionUpdateWriter createSubscriptionUpdateWriter( - OutputStream out) { - return new SubscriptionUpdateWriterImpl(out, writerFactory); - } - - public TransportUpdateWriter createTransportUpdateWriter(OutputStream out) { - return new TransportUpdateWriterImpl(out, writerFactory); - } -} diff --git a/components/net/sf/briar/protocol/writers/ProtocolWritersModule.java b/components/net/sf/briar/protocol/writers/ProtocolWritersModule.java deleted file mode 100644 index 7e83b36b0511020ffe133de15e4f77347271c20a..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/writers/ProtocolWritersModule.java +++ /dev/null @@ -1,13 +0,0 @@ -package net.sf.briar.protocol.writers; - -import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; - -import com.google.inject.AbstractModule; - -public class ProtocolWritersModule extends AbstractModule { - - @Override - protected void configure() { - bind(ProtocolWriterFactory.class).to(ProtocolWriterFactoryImpl.class); - } -} diff --git a/components/net/sf/briar/protocol/writers/RequestWriterImpl.java b/components/net/sf/briar/protocol/writers/RequestWriterImpl.java deleted file mode 100644 index 68b006006ffb099ede1ad39e7e9fe42b209e2c76..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/writers/RequestWriterImpl.java +++ /dev/null @@ -1,39 +0,0 @@ -package net.sf.briar.protocol.writers; - -import java.io.IOException; -import java.io.OutputStream; -import java.util.BitSet; - -import net.sf.briar.api.protocol.Types; -import net.sf.briar.api.protocol.writers.RequestWriter; -import net.sf.briar.api.serial.Writer; -import net.sf.briar.api.serial.WriterFactory; - -class RequestWriterImpl implements RequestWriter { - - private final OutputStream out; - private final Writer w; - - RequestWriterImpl(OutputStream out, WriterFactory writerFactory) { - this.out = out; - w = writerFactory.createWriter(out); - } - - public void writeRequest(BitSet b, int length) - throws IOException { - w.writeStructId(Types.REQUEST); - // If the number of bits isn't a multiple of 8, round up to a byte - int bytes = length % 8 == 0 ? length / 8 : length / 8 + 1; - byte[] bitmap = new byte[bytes]; - // I'm kind of surprised BitSet doesn't have a method for this - for(int i = 0; i < length; i++) { - if(b.get(i)) { - int offset = i / 8; - byte bit = (byte) (128 >> i % 8); - bitmap[offset] |= bit; - } - } - w.writeBytes(bitmap); - out.flush(); - } -} diff --git a/components/net/sf/briar/protocol/writers/SubscriptionUpdateWriterImpl.java b/components/net/sf/briar/protocol/writers/SubscriptionUpdateWriterImpl.java deleted file mode 100644 index 479955716f295c8da7a09412118536b44f15d6a5..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/writers/SubscriptionUpdateWriterImpl.java +++ /dev/null @@ -1,45 +0,0 @@ -package net.sf.briar.protocol.writers; - -import java.io.IOException; -import java.io.OutputStream; -import java.util.Map; -import java.util.Map.Entry; - -import net.sf.briar.api.protocol.Group; -import net.sf.briar.api.protocol.Types; -import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter; -import net.sf.briar.api.serial.Writer; -import net.sf.briar.api.serial.WriterFactory; - -class SubscriptionUpdateWriterImpl implements SubscriptionUpdateWriter { - - private final OutputStream out; - private final Writer w; - - SubscriptionUpdateWriterImpl(OutputStream out, - WriterFactory writerFactory) { - this.out = out; - w = writerFactory.createWriter(out); - } - - public void writeSubscriptions(Map<Group, Long> subs, long timestamp) - throws IOException { - w.writeStructId(Types.SUBSCRIPTION_UPDATE); - w.writeMapStart(); - for(Entry<Group, Long> e : subs.entrySet()) { - writeGroup(w, e.getKey()); - w.writeInt64(e.getValue()); - } - w.writeMapEnd(); - w.writeInt64(timestamp); - out.flush(); - } - - private void writeGroup(Writer w, Group g) throws IOException { - w.writeStructId(Types.GROUP); - w.writeString(g.getName()); - byte[] publicKey = g.getPublicKey(); - if(publicKey == null) w.writeNull(); - else w.writeBytes(publicKey); - } -} diff --git a/components/net/sf/briar/protocol/writers/TransportUpdateWriterImpl.java b/components/net/sf/briar/protocol/writers/TransportUpdateWriterImpl.java deleted file mode 100644 index fdfa0aede8d74179f8fedee19674ff8ae5b104fb..0000000000000000000000000000000000000000 --- a/components/net/sf/briar/protocol/writers/TransportUpdateWriterImpl.java +++ /dev/null @@ -1,37 +0,0 @@ -package net.sf.briar.protocol.writers; - -import java.io.IOException; -import java.io.OutputStream; -import java.util.Collection; - -import net.sf.briar.api.protocol.Transport; -import net.sf.briar.api.protocol.Types; -import net.sf.briar.api.protocol.writers.TransportUpdateWriter; -import net.sf.briar.api.serial.Writer; -import net.sf.briar.api.serial.WriterFactory; - -class TransportUpdateWriterImpl implements TransportUpdateWriter { - - private final OutputStream out; - private final Writer w; - - TransportUpdateWriterImpl(OutputStream out, WriterFactory writerFactory) { - this.out = out; - w = writerFactory.createWriter(out); - } - - public void writeTransports(Collection<Transport> transports, - long timestamp) throws IOException { - w.writeStructId(Types.TRANSPORT_UPDATE); - w.writeListStart(); - for(Transport p : transports) { - w.writeStructId(Types.TRANSPORT); - w.writeBytes(p.getId().getBytes()); - w.writeInt32(p.getIndex().getInt()); - w.writeMap(p); - } - w.writeListEnd(); - w.writeInt64(timestamp); - out.flush(); - } -} diff --git a/components/net/sf/briar/serial/SerialComponentImpl.java b/components/net/sf/briar/serial/SerialComponentImpl.java index 451a375d77cda19d1cc97d45829bb292db6d51bd..934e4b071f79cb568e8382fe540a768f2b9b352f 100644 --- a/components/net/sf/briar/serial/SerialComponentImpl.java +++ b/components/net/sf/briar/serial/SerialComponentImpl.java @@ -15,6 +15,11 @@ class SerialComponentImpl implements SerialComponent { return 1; } + public int getSerialisedStructIdLength(int id) { + if(id < 0 || id > 255) throw new IllegalArgumentException(); + return id < 32 ? 1 : 2; + } + public int getSerialisedUniqueIdLength(int id) { // Struct ID, BYTES tag, length spec, bytes return getSerialisedStructIdLength(id) + 1 @@ -22,14 +27,9 @@ class SerialComponentImpl implements SerialComponent { } private int getSerialisedLengthSpecLength(int length) { - assert length >= 0; + if(length < 0) throw new IllegalArgumentException(); if(length < 128) return 1; // Uint7 if(length < Short.MAX_VALUE) return 3; // Int16 return 5; // Int32 } - - public int getSerialisedStructIdLength(int id) { - assert id >= 0 && id <= 255; - return id < 32 ? 1 : 2; - } } diff --git a/components/net/sf/briar/transport/batch/BatchConnectionFactoryImpl.java b/components/net/sf/briar/transport/batch/BatchConnectionFactoryImpl.java index 81a215235702eb01a2f5769f81ff0baf47ce6a83..96d1d35e16327ce16a35319064283a251ce09fa0 100644 --- a/components/net/sf/briar/transport/batch/BatchConnectionFactoryImpl.java +++ b/components/net/sf/briar/transport/batch/BatchConnectionFactoryImpl.java @@ -5,8 +5,8 @@ import java.util.concurrent.Executor; import net.sf.briar.api.ContactId; import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.protocol.ProtocolReaderFactory; +import net.sf.briar.api.protocol.ProtocolWriterFactory; import net.sf.briar.api.protocol.TransportIndex; -import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; import net.sf.briar.api.transport.BatchConnectionFactory; import net.sf.briar.api.transport.BatchTransportReader; import net.sf.briar.api.transport.BatchTransportWriter; @@ -19,22 +19,22 @@ import com.google.inject.Inject; class BatchConnectionFactoryImpl implements BatchConnectionFactory { private final Executor executor; + private final DatabaseComponent db; private final ConnectionReaderFactory connReaderFactory; private final ConnectionWriterFactory connWriterFactory; - private final DatabaseComponent db; private final ProtocolReaderFactory protoReaderFactory; private final ProtocolWriterFactory protoWriterFactory; @Inject - BatchConnectionFactoryImpl(Executor executor, + BatchConnectionFactoryImpl(Executor executor, DatabaseComponent db, ConnectionReaderFactory connReaderFactory, - ConnectionWriterFactory connWriterFactory, DatabaseComponent db, + ConnectionWriterFactory connWriterFactory, ProtocolReaderFactory protoReaderFactory, ProtocolWriterFactory protoWriterFactory) { this.executor = executor; + this.db = db; this.connReaderFactory = connReaderFactory; this.connWriterFactory = connWriterFactory; - this.db = db; this.protoReaderFactory = protoReaderFactory; this.protoWriterFactory = protoWriterFactory; } @@ -42,7 +42,7 @@ class BatchConnectionFactoryImpl implements BatchConnectionFactory { public void createIncomingConnection(ConnectionContext ctx, BatchTransportReader r, byte[] tag) { final IncomingBatchConnection conn = new IncomingBatchConnection( - executor, connReaderFactory, db, protoReaderFactory, ctx, r, + executor, db, connReaderFactory, protoReaderFactory, ctx, r, tag); Runnable read = new Runnable() { public void run() { @@ -54,8 +54,8 @@ class BatchConnectionFactoryImpl implements BatchConnectionFactory { public void createOutgoingConnection(ContactId c, TransportIndex i, BatchTransportWriter w) { - final OutgoingBatchConnection conn = new OutgoingBatchConnection( - connWriterFactory, db, protoWriterFactory, c, i, w); + final OutgoingBatchConnection conn = new OutgoingBatchConnection(db, + connWriterFactory, protoWriterFactory, c, i, w); Runnable write = new Runnable() { public void run() { conn.write(); diff --git a/components/net/sf/briar/transport/batch/IncomingBatchConnection.java b/components/net/sf/briar/transport/batch/IncomingBatchConnection.java index e1a5082f062dc65e9eabd35e20209d6b20542844..326667063c0c507f5567c0d5745e960cc5d6bc27 100644 --- a/components/net/sf/briar/transport/batch/IncomingBatchConnection.java +++ b/components/net/sf/briar/transport/batch/IncomingBatchConnection.java @@ -39,8 +39,8 @@ class IncomingBatchConnection { private final Semaphore semaphore; IncomingBatchConnection(Executor executor, - ConnectionReaderFactory connFactory, - DatabaseComponent db, ProtocolReaderFactory protoFactory, + DatabaseComponent db, + ConnectionReaderFactory connFactory, ProtocolReaderFactory protoFactory, ConnectionContext ctx, BatchTransportReader reader, byte[] tag) { this.executor = executor; this.connFactory = connFactory; diff --git a/components/net/sf/briar/transport/batch/OutgoingBatchConnection.java b/components/net/sf/briar/transport/batch/OutgoingBatchConnection.java index 84c32cce5e7d429ef7ca2ffffcf686f0cd017c9c..30a178c695e8649624a7dced05b95c5c2a80d2c1 100644 --- a/components/net/sf/briar/transport/batch/OutgoingBatchConnection.java +++ b/components/net/sf/briar/transport/batch/OutgoingBatchConnection.java @@ -10,12 +10,13 @@ import java.util.logging.Logger; import net.sf.briar.api.ContactId; import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DbException; +import net.sf.briar.api.protocol.Ack; +import net.sf.briar.api.protocol.ProtocolWriter; +import net.sf.briar.api.protocol.ProtocolWriterFactory; +import net.sf.briar.api.protocol.RawBatch; +import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.TransportIndex; -import net.sf.briar.api.protocol.writers.AckWriter; -import net.sf.briar.api.protocol.writers.BatchWriter; -import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; -import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter; -import net.sf.briar.api.protocol.writers.TransportUpdateWriter; +import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.transport.BatchTransportWriter; import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionWriter; @@ -26,23 +27,23 @@ class OutgoingBatchConnection { private static final Logger LOG = Logger.getLogger(OutgoingBatchConnection.class.getName()); - private final ConnectionWriterFactory connFactory; private final DatabaseComponent db; + private final ConnectionWriterFactory connFactory; private final ProtocolWriterFactory protoFactory; private final ContactId contactId; private final TransportIndex transportIndex; - private final BatchTransportWriter writer; + private final BatchTransportWriter transport; - OutgoingBatchConnection(ConnectionWriterFactory connFactory, - DatabaseComponent db, ProtocolWriterFactory protoFactory, - ContactId contactId, TransportIndex transportIndex, - BatchTransportWriter writer) { - this.connFactory = connFactory; + OutgoingBatchConnection(DatabaseComponent db, + ConnectionWriterFactory connFactory, + ProtocolWriterFactory protoFactory, ContactId contactId, + TransportIndex transportIndex, BatchTransportWriter transport) { this.db = db; + this.connFactory = connFactory; this.protoFactory = protoFactory; this.contactId = contactId; this.transportIndex = transportIndex; - this.writer = writer; + this.transport = transport; } void write() { @@ -50,45 +51,52 @@ class OutgoingBatchConnection { ConnectionContext ctx = db.getConnectionContext(contactId, transportIndex); ConnectionWriter conn = connFactory.createConnectionWriter( - writer.getOutputStream(), writer.getCapacity(), + transport.getOutputStream(), transport.getCapacity(), ctx.getSecret()); OutputStream out = conn.getOutputStream(); + ProtocolWriter proto = protoFactory.createProtocolWriter(out); // There should be enough space for a packet long capacity = conn.getRemainingCapacity(); if(capacity < MAX_PACKET_LENGTH) throw new IOException(); // Write a transport update - TransportUpdateWriter t = - protoFactory.createTransportUpdateWriter(out); - db.generateTransportUpdate(contactId, t); + TransportUpdate t = db.generateTransportUpdate(contactId); + if(t != null) proto.writeTransportUpdate(t); // If there's space, write a subscription update capacity = conn.getRemainingCapacity(); if(capacity >= MAX_PACKET_LENGTH) { - SubscriptionUpdateWriter s = - protoFactory.createSubscriptionUpdateWriter(out); - db.generateSubscriptionUpdate(contactId, s); + SubscriptionUpdate s = db.generateSubscriptionUpdate(contactId); + if(s != null) proto.writeSubscriptionUpdate(s); } // Write acks until you can't write acks no more - AckWriter a = protoFactory.createAckWriter(out); - do { + capacity = conn.getRemainingCapacity(); + int maxBatches = proto.getMaxBatchesForAck(capacity); + Ack a = db.generateAck(contactId, maxBatches); + while(a != null) { + proto.writeAck(a); capacity = conn.getRemainingCapacity(); - int max = (int) Math.min(MAX_PACKET_LENGTH, capacity); - a.setMaxPacketLength(max); - } while(db.generateAck(contactId, a)); + maxBatches = proto.getMaxBatchesForAck(capacity); + a = db.generateAck(contactId, maxBatches); + } // Write batches until you can't write batches no more - BatchWriter b = protoFactory.createBatchWriter(out); - do { + capacity = conn.getRemainingCapacity(); + capacity = proto.getMessageCapacityForBatch(capacity); + RawBatch b = db.generateBatch(contactId, (int) capacity); + while(b != null) { + proto.writeBatch(b); capacity = conn.getRemainingCapacity(); - int max = (int) Math.min(MAX_PACKET_LENGTH, capacity); - b.setMaxPacketLength(max); - } while(db.generateBatch(contactId, b)); + capacity = proto.getMessageCapacityForBatch(capacity); + b = db.generateBatch(contactId, (int) capacity); + } + // Flush the output stream + out.flush(); } catch(DbException e) { if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); - writer.dispose(false); + transport.dispose(false); } catch(IOException e) { if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.getMessage()); - writer.dispose(false); + transport.dispose(false); } // Success - writer.dispose(true); + transport.dispose(true); } } diff --git a/components/net/sf/briar/transport/stream/IncomingStreamConnection.java b/components/net/sf/briar/transport/stream/IncomingStreamConnection.java index 006659e103babaa50ae2d700c27610c6fa64dfb9..559d9fe2631e36f66b4c82cf9cfcb947f9460793 100644 --- a/components/net/sf/briar/transport/stream/IncomingStreamConnection.java +++ b/components/net/sf/briar/transport/stream/IncomingStreamConnection.java @@ -6,7 +6,8 @@ import java.util.concurrent.Executor; import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DbException; import net.sf.briar.api.protocol.ProtocolReaderFactory; -import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; +import net.sf.briar.api.protocol.ProtocolWriterFactory; +import net.sf.briar.api.serial.SerialComponent; import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReaderFactory; @@ -19,14 +20,14 @@ class IncomingStreamConnection extends StreamConnection { private final ConnectionContext ctx; private final byte[] tag; - IncomingStreamConnection(Executor executor, - ConnectionReaderFactory connReaderFactory, - ConnectionWriterFactory connWriterFactory, DatabaseComponent db, + IncomingStreamConnection(Executor executor, DatabaseComponent db, + SerialComponent serial, ConnectionReaderFactory connReaderFactory, + ConnectionWriterFactory connWriterFactory, ProtocolReaderFactory protoReaderFactory, ProtocolWriterFactory protoWriterFactory, ConnectionContext ctx, StreamTransportConnection connection, byte[] tag) { - super(executor, connReaderFactory, connWriterFactory, db, + super(executor, db, serial, connReaderFactory, connWriterFactory, protoReaderFactory, protoWriterFactory, ctx.getContactId(), connection); this.ctx = ctx; diff --git a/components/net/sf/briar/transport/stream/OutgoingStreamConnection.java b/components/net/sf/briar/transport/stream/OutgoingStreamConnection.java index 5f1dabbbbc5e6710381f6261ad4bc1d0b08ec287..0594f32ddc97de8f425666a7d4472527a5813fe2 100644 --- a/components/net/sf/briar/transport/stream/OutgoingStreamConnection.java +++ b/components/net/sf/briar/transport/stream/OutgoingStreamConnection.java @@ -7,8 +7,9 @@ import net.sf.briar.api.ContactId; import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DbException; import net.sf.briar.api.protocol.ProtocolReaderFactory; +import net.sf.briar.api.protocol.ProtocolWriterFactory; import net.sf.briar.api.protocol.TransportIndex; -import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; +import net.sf.briar.api.serial.SerialComponent; import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReaderFactory; @@ -22,14 +23,14 @@ class OutgoingStreamConnection extends StreamConnection { private ConnectionContext ctx = null; // Locking: this - OutgoingStreamConnection(Executor executor, - ConnectionReaderFactory connReaderFactory, - ConnectionWriterFactory connWriterFactory, DatabaseComponent db, + OutgoingStreamConnection(Executor executor, DatabaseComponent db, + SerialComponent serial, ConnectionReaderFactory connReaderFactory, + ConnectionWriterFactory connWriterFactory, ProtocolReaderFactory protoReaderFactory, ProtocolWriterFactory protoWriterFactory, ContactId contactId, TransportIndex transportIndex, StreamTransportConnection connection) { - super(executor, connReaderFactory, connWriterFactory, db, + super(executor, db, serial, connReaderFactory, connWriterFactory, protoReaderFactory, protoWriterFactory, contactId, connection); this.transportIndex = transportIndex; } diff --git a/components/net/sf/briar/transport/stream/StreamConnection.java b/components/net/sf/briar/transport/stream/StreamConnection.java index 8e1cf6730c63ae9f02416af02374783e6e525171..ee138e9ef7d1677bf920e60d8b2a2b95ea632bdd 100644 --- a/components/net/sf/briar/transport/stream/StreamConnection.java +++ b/components/net/sf/briar/transport/stream/StreamConnection.java @@ -31,17 +31,14 @@ import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.Offer; import net.sf.briar.api.protocol.ProtocolReader; import net.sf.briar.api.protocol.ProtocolReaderFactory; +import net.sf.briar.api.protocol.ProtocolWriter; +import net.sf.briar.api.protocol.ProtocolWriterFactory; +import net.sf.briar.api.protocol.RawBatch; import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.UnverifiedBatch; -import net.sf.briar.api.protocol.writers.AckWriter; -import net.sf.briar.api.protocol.writers.BatchWriter; -import net.sf.briar.api.protocol.writers.OfferWriter; -import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; -import net.sf.briar.api.protocol.writers.RequestWriter; -import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter; -import net.sf.briar.api.protocol.writers.TransportUpdateWriter; +import net.sf.briar.api.serial.SerialComponent; import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionWriter; @@ -58,9 +55,10 @@ abstract class StreamConnection implements DatabaseListener { Logger.getLogger(StreamConnection.class.getName()); protected final Executor executor; + protected final DatabaseComponent db; + protected final SerialComponent serial; protected final ConnectionReaderFactory connReaderFactory; protected final ConnectionWriterFactory connWriterFactory; - protected final DatabaseComponent db; protected final ProtocolReaderFactory protoReaderFactory; protected final ProtocolWriterFactory protoWriterFactory; protected final ContactId contactId; @@ -73,16 +71,17 @@ abstract class StreamConnection implements DatabaseListener { private LinkedList<MessageId> requested = null; // Locking: this private Offer incomingOffer = null; // Locking: this - StreamConnection(Executor executor, - ConnectionReaderFactory connReaderFactory, - ConnectionWriterFactory connWriterFactory, DatabaseComponent db, + StreamConnection(Executor executor, DatabaseComponent db, + SerialComponent serial, ConnectionReaderFactory connReaderFactory, + ConnectionWriterFactory connWriterFactory, ProtocolReaderFactory protoReaderFactory, ProtocolWriterFactory protoWriterFactory, ContactId contactId, StreamTransportConnection connection) { this.executor = executor; + this.db = db; + this.serial = serial; this.connReaderFactory = connReaderFactory; this.connWriterFactory = connWriterFactory; - this.db = db; this.protoReaderFactory = protoReaderFactory; this.protoWriterFactory = protoWriterFactory; this.contactId = contactId; @@ -267,20 +266,11 @@ abstract class StreamConnection implements DatabaseListener { void write() { try { OutputStream out = createConnectionWriter().getOutputStream(); - // Create the packet writers - AckWriter ackWriter = protoWriterFactory.createAckWriter(out); - BatchWriter batchWriter = protoWriterFactory.createBatchWriter(out); - OfferWriter offerWriter = protoWriterFactory.createOfferWriter(out); - RequestWriter requestWriter = - protoWriterFactory.createRequestWriter(out); - SubscriptionUpdateWriter subscriptionUpdateWriter = - protoWriterFactory.createSubscriptionUpdateWriter(out); - TransportUpdateWriter transportUpdateWriter = - protoWriterFactory.createTransportUpdateWriter(out); + ProtocolWriter proto = protoWriterFactory.createProtocolWriter(out); // Send the initial packets: transports, subs, any waiting acks - sendTransportUpdate(transportUpdateWriter); - sendSubscriptionUpdate(subscriptionUpdateWriter); - sendAcks(ackWriter); + sendTransportUpdate(proto); + sendSubscriptionUpdate(proto); + sendAcks(proto); State state = State.SEND_OFFER; // Main loop while(true) { @@ -289,7 +279,7 @@ abstract class StreamConnection implements DatabaseListener { case SEND_OFFER: // Try to send an offer - if(sendOffer(offerWriter)) state = State.AWAIT_REQUEST; + if(sendOffer(proto)) state = State.AWAIT_REQUEST; else state = State.IDLE; break; @@ -312,16 +302,16 @@ abstract class StreamConnection implements DatabaseListener { return; } if((flags & Flags.TRANSPORTS_UPDATED) != 0) { - sendTransportUpdate(transportUpdateWriter); + sendTransportUpdate(proto); } if((flags & Flags.SUBSCRIPTIONS_UPDATED) != 0) { - sendSubscriptionUpdate(subscriptionUpdateWriter); + sendSubscriptionUpdate(proto); } if((flags & Flags.BATCH_RECEIVED) != 0) { - sendAcks(ackWriter); + sendAcks(proto); } if((flags & Flags.OFFER_RECEIVED) != 0) { - sendRequest(requestWriter); + sendRequest(proto); } if((flags & Flags.REQUEST_RECEIVED) != 0) { // Should only be received in state AWAIT_REQUEST @@ -351,16 +341,16 @@ abstract class StreamConnection implements DatabaseListener { return; } if((flags & Flags.TRANSPORTS_UPDATED) != 0) { - sendTransportUpdate(transportUpdateWriter); + sendTransportUpdate(proto); } if((flags & Flags.SUBSCRIPTIONS_UPDATED) != 0) { - sendSubscriptionUpdate(subscriptionUpdateWriter); + sendSubscriptionUpdate(proto); } if((flags & Flags.BATCH_RECEIVED) != 0) { - sendAcks(ackWriter); + sendAcks(proto); } if((flags & Flags.OFFER_RECEIVED) != 0) { - sendRequest(requestWriter); + sendRequest(proto); } if((flags & Flags.REQUEST_RECEIVED) != 0) { state = State.SEND_BATCHES; @@ -382,16 +372,16 @@ abstract class StreamConnection implements DatabaseListener { return; } if((flags & Flags.TRANSPORTS_UPDATED) != 0) { - sendTransportUpdate(transportUpdateWriter); + sendTransportUpdate(proto); } if((flags & Flags.SUBSCRIPTIONS_UPDATED) != 0) { - sendSubscriptionUpdate(subscriptionUpdateWriter); + sendSubscriptionUpdate(proto); } if((flags & Flags.BATCH_RECEIVED) != 0) { - sendAcks(ackWriter); + sendAcks(proto); } if((flags & Flags.OFFER_RECEIVED) != 0) { - sendRequest(requestWriter); + sendRequest(proto); } if((flags & Flags.REQUEST_RECEIVED) != 0) { // Should only be received in state AWAIT_REQUEST @@ -401,7 +391,7 @@ abstract class StreamConnection implements DatabaseListener { // Ignored in this state } // Try to send a batch - if(!sendBatch(batchWriter)) state = State.SEND_OFFER; + if(!sendBatch(proto)) state = State.SEND_OFFER; break; } } @@ -416,11 +406,18 @@ abstract class StreamConnection implements DatabaseListener { connection.dispose(true); } - private void sendAcks(AckWriter a) throws DbException, IOException { - while(db.generateAck(contactId, a)); + private void sendAcks(ProtocolWriter proto) + throws DbException, IOException { + int maxBatches = proto.getMaxBatchesForAck(Long.MAX_VALUE); + Ack a = db.generateAck(contactId, maxBatches); + while(a != null) { + proto.writeAck(a); + a = db.generateAck(contactId, maxBatches); + } } - private boolean sendBatch(BatchWriter b) throws DbException, IOException { + private boolean sendBatch(ProtocolWriter proto) + throws DbException, IOException { Collection<MessageId> req; // Retrieve the requested message IDs synchronized(this) { @@ -429,31 +426,40 @@ abstract class StreamConnection implements DatabaseListener { req = requested; } // Try to generate a batch, updating the collection of message IDs - boolean anyAdded = db.generateBatch(contactId, b, req); - // If no more batches can be generated, discard the remaining IDs - if(!anyAdded) { + int capacity = proto.getMessageCapacityForBatch(Long.MAX_VALUE); + RawBatch b = db.generateBatch(contactId, capacity, req); + if(b == null) { + // No more batches can be generated - discard the remaining IDs synchronized(this) { assert offered == null; assert requested == req; requested = null; } + return false; + } else { + proto.writeBatch(b); + return true; } - return anyAdded; } - private boolean sendOffer(OfferWriter o) throws DbException, IOException { + private boolean sendOffer(ProtocolWriter proto) + throws DbException, IOException { // Generate an offer - Collection<MessageId> off = db.generateOffer(contactId, o); + int maxMessages = proto.getMaxMessagesForOffer(Long.MAX_VALUE); + Offer o = db.generateOffer(contactId, maxMessages); + if(o == null) return false; + proto.writeOffer(o); // Store the offered message IDs synchronized(this) { assert offered == null; assert requested == null; - offered = off; + offered = o.getMessageIds(); } - return !off.isEmpty(); + return true; } - private void sendRequest(RequestWriter r) throws DbException, IOException { + private void sendRequest(ProtocolWriter proto) + throws DbException, IOException { Offer o; // Retrieve the incoming offer synchronized(this) { @@ -462,16 +468,19 @@ abstract class StreamConnection implements DatabaseListener { incomingOffer = null; } // Process the offer and generate a request - db.receiveOffer(contactId, o, r); + Request r = db.receiveOffer(contactId, o); + proto.writeRequest(r); } - private void sendTransportUpdate(TransportUpdateWriter t) + private void sendTransportUpdate(ProtocolWriter proto) throws DbException, IOException { - db.generateTransportUpdate(contactId, t); + TransportUpdate t = db.generateTransportUpdate(contactId); + if(t != null) proto.writeTransportUpdate(t); } - private void sendSubscriptionUpdate(SubscriptionUpdateWriter s) + private void sendSubscriptionUpdate(ProtocolWriter proto) throws DbException, IOException { - db.generateSubscriptionUpdate(contactId, s); + SubscriptionUpdate s = db.generateSubscriptionUpdate(contactId); + if(s != null) proto.writeSubscriptionUpdate(s); } } diff --git a/components/net/sf/briar/transport/stream/StreamConnectionFactoryImpl.java b/components/net/sf/briar/transport/stream/StreamConnectionFactoryImpl.java index 659d293aeeb73d7a4725913d27bb48959c7528b2..7217b862dd87ab3683234ca408edf3834738fea8 100644 --- a/components/net/sf/briar/transport/stream/StreamConnectionFactoryImpl.java +++ b/components/net/sf/briar/transport/stream/StreamConnectionFactoryImpl.java @@ -5,8 +5,9 @@ import java.util.concurrent.Executor; import net.sf.briar.api.ContactId; import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.protocol.ProtocolReaderFactory; +import net.sf.briar.api.protocol.ProtocolWriterFactory; import net.sf.briar.api.protocol.TransportIndex; -import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; +import net.sf.briar.api.serial.SerialComponent; import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionWriterFactory; @@ -18,31 +19,33 @@ import com.google.inject.Inject; class StreamConnectionFactoryImpl implements StreamConnectionFactory { private final Executor executor; + private final DatabaseComponent db; + private final SerialComponent serial; private final ConnectionReaderFactory connReaderFactory; private final ConnectionWriterFactory connWriterFactory; - private final DatabaseComponent db; private final ProtocolReaderFactory protoReaderFactory; private final ProtocolWriterFactory protoWriterFactory; @Inject - StreamConnectionFactoryImpl(Executor executor, - ConnectionReaderFactory connReaderFactory, - ConnectionWriterFactory connWriterFactory, DatabaseComponent db, + StreamConnectionFactoryImpl(Executor executor, DatabaseComponent db, + SerialComponent serial, ConnectionReaderFactory connReaderFactory, + ConnectionWriterFactory connWriterFactory, ProtocolReaderFactory protoReaderFactory, ProtocolWriterFactory protoWriterFactory) { this.executor = executor; + this.db = db; + this.serial = serial; this.connReaderFactory = connReaderFactory; this.connWriterFactory = connWriterFactory; - this.db = db; this.protoReaderFactory = protoReaderFactory; this.protoWriterFactory = protoWriterFactory; } public void createIncomingConnection(ConnectionContext ctx, StreamTransportConnection s, byte[] tag) { - final StreamConnection conn = new IncomingStreamConnection(executor, - connReaderFactory, connWriterFactory, db, protoReaderFactory, - protoWriterFactory, ctx, s, tag); + final StreamConnection conn = new IncomingStreamConnection(executor, db, + serial, connReaderFactory, connWriterFactory, + protoReaderFactory, protoWriterFactory, ctx, s, tag); Runnable write = new Runnable() { public void run() { conn.write(); @@ -59,9 +62,9 @@ class StreamConnectionFactoryImpl implements StreamConnectionFactory { public void createOutgoingConnection(ContactId c, TransportIndex i, StreamTransportConnection s) { - final StreamConnection conn = new OutgoingStreamConnection(executor, - connReaderFactory, connWriterFactory, db, protoReaderFactory, - protoWriterFactory, c, i, s); + final StreamConnection conn = new OutgoingStreamConnection(executor, db, + serial, connReaderFactory, connWriterFactory, + protoReaderFactory, protoWriterFactory, c, i, s); Runnable write = new Runnable() { public void run() { conn.write(); diff --git a/test/build.xml b/test/build.xml index e7df24043d059788a956d140e7a82852ad623db9..f40c9ce045ba26a020ac711414c41dfac812eded 100644 --- a/test/build.xml +++ b/test/build.xml @@ -37,11 +37,11 @@ <test name='net.sf.briar.plugins.socket.SimpleSocketPluginTest'/> <test name='net.sf.briar.protocol.AckReaderTest'/> <test name='net.sf.briar.protocol.BatchReaderTest'/> + <test name='net.sf.briar.protocol.ConstantsTest'/> <test name='net.sf.briar.protocol.ConsumersTest'/> <test name='net.sf.briar.protocol.ProtocolReadWriteTest'/> + <test name='net.sf.briar.protocol.ProtocolWriterImplTest'/> <test name='net.sf.briar.protocol.RequestReaderTest'/> - <test name='net.sf.briar.protocol.writers.ConstantsTest'/> - <test name='net.sf.briar.protocol.writers.RequestWriterImplTest'/> <test name='net.sf.briar.serial.ReaderImplTest'/> <test name='net.sf.briar.serial.WriterImplTest'/> <test name='net.sf.briar.setup.SetupWorkerTest'/> diff --git a/test/net/sf/briar/ProtocolIntegrationTest.java b/test/net/sf/briar/ProtocolIntegrationTest.java index 7d5434a62503e9b194af1da8f5fd56de42875622..ca726f7a866def51efae31b5d012db204d1a8ff8 100644 --- a/test/net/sf/briar/ProtocolIntegrationTest.java +++ b/test/net/sf/briar/ProtocolIntegrationTest.java @@ -8,6 +8,7 @@ import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.io.OutputStream; import java.security.KeyPair; +import java.util.Arrays; import java.util.BitSet; import java.util.Collection; import java.util.Collections; @@ -16,7 +17,7 @@ import java.util.LinkedHashMap; import java.util.Map; import java.util.Random; import java.util.concurrent.Executor; -import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.Executors; import junit.framework.TestCase; import net.sf.briar.api.crypto.CryptoComponent; @@ -31,21 +32,18 @@ import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.MessageFactory; import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.Offer; +import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.ProtocolReader; import net.sf.briar.api.protocol.ProtocolReaderFactory; +import net.sf.briar.api.protocol.ProtocolWriter; +import net.sf.briar.api.protocol.ProtocolWriterFactory; +import net.sf.briar.api.protocol.RawBatch; import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportUpdate; -import net.sf.briar.api.protocol.writers.AckWriter; -import net.sf.briar.api.protocol.writers.BatchWriter; -import net.sf.briar.api.protocol.writers.OfferWriter; -import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; -import net.sf.briar.api.protocol.writers.RequestWriter; -import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter; -import net.sf.briar.api.protocol.writers.TransportUpdateWriter; import net.sf.briar.api.transport.ConnectionReader; import net.sf.briar.api.transport.ConnectionReaderFactory; import net.sf.briar.api.transport.ConnectionWriter; @@ -54,7 +52,6 @@ import net.sf.briar.crypto.CryptoModule; import net.sf.briar.db.DatabaseModule; import net.sf.briar.lifecycle.LifecycleModule; import net.sf.briar.protocol.ProtocolModule; -import net.sf.briar.protocol.writers.ProtocolWritersModule; import net.sf.briar.serial.SerialModule; import net.sf.briar.transport.TransportModule; import net.sf.briar.transport.batch.TransportBatchModule; @@ -76,6 +73,7 @@ public class ProtocolIntegrationTest extends TestCase { private final ConnectionWriterFactory connectionWriterFactory; private final ProtocolReaderFactory protocolReaderFactory; private final ProtocolWriterFactory protocolWriterFactory; + private final PacketFactory packetFactory; private final CryptoComponent crypto; private final byte[] secret; private final TransportIndex transportIndex = new TransportIndex(13); @@ -93,19 +91,19 @@ public class ProtocolIntegrationTest extends TestCase { @Override public void configure() { bind(Executor.class).toInstance( - new ScheduledThreadPoolExecutor(5)); + Executors.newCachedThreadPool()); } }; Injector i = Guice.createInjector(testModule, new CryptoModule(), new DatabaseModule(), new LifecycleModule(), - new ProtocolModule(), new ProtocolWritersModule(), - new SerialModule(), new TestDatabaseModule(), - new TransportBatchModule(), new TransportModule(), - new TransportStreamModule()); + new ProtocolModule(), new SerialModule(), + new TestDatabaseModule(), new TransportBatchModule(), + new TransportModule(), new TransportStreamModule()); connectionReaderFactory = i.getInstance(ConnectionReaderFactory.class); connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class); protocolReaderFactory = i.getInstance(ProtocolReaderFactory.class); protocolWriterFactory = i.getInstance(ProtocolWriterFactory.class); + packetFactory = i.getInstance(PacketFactory.class); crypto = i.getInstance(CryptoComponent.class); // Create a shared secret Random r = new Random(); @@ -149,47 +147,51 @@ public class ProtocolIntegrationTest extends TestCase { private byte[] write() throws Exception { ByteArrayOutputStream out = new ByteArrayOutputStream(); - ConnectionWriter w = connectionWriterFactory.createConnectionWriter(out, - Long.MAX_VALUE, secret.clone()); - OutputStream out1 = w.getOutputStream(); + ConnectionWriter conn = connectionWriterFactory.createConnectionWriter( + out, Long.MAX_VALUE, secret.clone()); + OutputStream out1 = conn.getOutputStream(); + ProtocolWriter proto = protocolWriterFactory.createProtocolWriter(out1); - AckWriter a = protocolWriterFactory.createAckWriter(out1); - assertTrue(a.writeBatchId(ack)); - a.finish(); + Ack a = packetFactory.createAck(Collections.singletonList(ack)); + proto.writeAck(a); - BatchWriter b = protocolWriterFactory.createBatchWriter(out1); - assertTrue(b.writeMessage(message.getSerialised())); - assertTrue(b.writeMessage(message1.getSerialised())); - assertTrue(b.writeMessage(message2.getSerialised())); - assertTrue(b.writeMessage(message3.getSerialised())); - b.finish(); + Collection<byte[]> batch = Arrays.asList(new byte[][] { + message.getSerialised(), + message1.getSerialised(), + message2.getSerialised(), + message3.getSerialised() + }); + RawBatch b = packetFactory.createBatch(batch); + proto.writeBatch(b); - OfferWriter o = protocolWriterFactory.createOfferWriter(out1); - assertTrue(o.writeMessageId(message.getId())); - assertTrue(o.writeMessageId(message1.getId())); - assertTrue(o.writeMessageId(message2.getId())); - assertTrue(o.writeMessageId(message3.getId())); - o.finish(); + Collection<MessageId> offer = Arrays.asList(new MessageId[] { + message.getId(), + message1.getId(), + message2.getId(), + message3.getId() + }); + Offer o = packetFactory.createOffer(offer); + proto.writeOffer(o); - RequestWriter r = protocolWriterFactory.createRequestWriter(out1); BitSet requested = new BitSet(4); requested.set(1); requested.set(3); - r.writeRequest(requested, 4); + Request r = packetFactory.createRequest(requested, 4); + proto.writeRequest(r); - SubscriptionUpdateWriter s = - protocolWriterFactory.createSubscriptionUpdateWriter(out1); // Use a LinkedHashMap for predictable iteration order Map<Group, Long> subs = new LinkedHashMap<Group, Long>(); subs.put(group, 0L); subs.put(group1, 0L); - s.writeSubscriptions(subs, timestamp); + SubscriptionUpdate s = packetFactory.createSubscriptionUpdate(subs, + timestamp); + proto.writeSubscriptionUpdate(s); - TransportUpdateWriter t = - protocolWriterFactory.createTransportUpdateWriter(out1); - t.writeTransports(transports, timestamp); + TransportUpdate t = packetFactory.createTransportUpdate(transports, + timestamp); + proto.writeTransportUpdate(t); - out1.close(); + out1.flush(); return out.toByteArray(); } diff --git a/test/net/sf/briar/db/DatabaseComponentImplTest.java b/test/net/sf/briar/db/DatabaseComponentImplTest.java index 0a2c3f612c0b39b65dd90e4c5a240419a1bfbc11..ccab25ee694414696e26316fb89a8ad75bcde4e9 100644 --- a/test/net/sf/briar/db/DatabaseComponentImplTest.java +++ b/test/net/sf/briar/db/DatabaseComponentImplTest.java @@ -8,6 +8,7 @@ import java.util.Collections; import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.DbException; import net.sf.briar.api.lifecycle.ShutdownManager; +import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.db.DatabaseCleaner.Callback; import org.jmock.Expectations; @@ -27,11 +28,13 @@ public class DatabaseComponentImplTest extends DatabaseComponentTest { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); context.checking(new Expectations() {{ oneOf(database).getFreeSpace(); will(returnValue(MIN_FREE_SPACE)); }}); - Callback db = createDatabaseComponentImpl(database, cleaner, shutdown); + Callback db = createDatabaseComponentImpl(database, cleaner, shutdown, + packetFactory); db.checkFreeSpaceAndClean(); @@ -45,6 +48,7 @@ public class DatabaseComponentImplTest extends DatabaseComponentTest { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); context.checking(new Expectations() {{ oneOf(database).getFreeSpace(); will(returnValue(MIN_FREE_SPACE - 1)); @@ -57,7 +61,8 @@ public class DatabaseComponentImplTest extends DatabaseComponentTest { oneOf(database).getFreeSpace(); will(returnValue(MIN_FREE_SPACE)); }}); - Callback db = createDatabaseComponentImpl(database, cleaner, shutdown); + Callback db = createDatabaseComponentImpl(database, cleaner, shutdown, + packetFactory); db.checkFreeSpaceAndClean(); @@ -72,6 +77,7 @@ public class DatabaseComponentImplTest extends DatabaseComponentTest { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); context.checking(new Expectations() {{ oneOf(database).getFreeSpace(); will(returnValue(MIN_FREE_SPACE - 1)); @@ -86,7 +92,8 @@ public class DatabaseComponentImplTest extends DatabaseComponentTest { oneOf(database).getFreeSpace(); will(returnValue(MIN_FREE_SPACE)); }}); - Callback db = createDatabaseComponentImpl(database, cleaner, shutdown); + Callback db = createDatabaseComponentImpl(database, cleaner, shutdown, + packetFactory); db.checkFreeSpaceAndClean(); @@ -101,6 +108,7 @@ public class DatabaseComponentImplTest extends DatabaseComponentTest { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); context.checking(new Expectations() {{ oneOf(database).getFreeSpace(); will(returnValue(MIN_FREE_SPACE - 1)); @@ -117,7 +125,8 @@ public class DatabaseComponentImplTest extends DatabaseComponentTest { oneOf(database).getFreeSpace(); will(returnValue(MIN_FREE_SPACE)); }}); - Callback db = createDatabaseComponentImpl(database, cleaner, shutdown); + Callback db = createDatabaseComponentImpl(database, cleaner, shutdown, + packetFactory); db.checkFreeSpaceAndClean(); @@ -127,13 +136,15 @@ public class DatabaseComponentImplTest extends DatabaseComponentTest { @Override protected <T> DatabaseComponent createDatabaseComponent( Database<T> database, DatabaseCleaner cleaner, - ShutdownManager shutdown) { - return createDatabaseComponentImpl(database, cleaner, shutdown); + ShutdownManager shutdown, PacketFactory packetFactory) { + return createDatabaseComponentImpl(database, cleaner, shutdown, + packetFactory); } private <T> DatabaseComponentImpl<T> createDatabaseComponentImpl( Database<T> database, DatabaseCleaner cleaner, - ShutdownManager shutdown) { - return new DatabaseComponentImpl<T>(database, cleaner, shutdown); + ShutdownManager shutdown, PacketFactory packetFactory) { + return new DatabaseComponentImpl<T>(database, cleaner, shutdown, + packetFactory); } } diff --git a/test/net/sf/briar/db/DatabaseComponentTest.java b/test/net/sf/briar/db/DatabaseComponentTest.java index e17b18ed1aebcf201e7c01be909ce57a923b0151..b196cfcd549cc39cd68d8f447fd0bc9e3e3eb64d 100644 --- a/test/net/sf/briar/db/DatabaseComponentTest.java +++ b/test/net/sf/briar/db/DatabaseComponentTest.java @@ -1,6 +1,7 @@ package net.sf.briar.db; import java.util.ArrayList; +import java.util.Arrays; import java.util.BitSet; import java.util.Collection; import java.util.Collections; @@ -32,18 +33,14 @@ import net.sf.briar.api.protocol.GroupId; import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.MessageId; import net.sf.briar.api.protocol.Offer; -import net.sf.briar.api.protocol.ProtocolConstants; +import net.sf.briar.api.protocol.PacketFactory; +import net.sf.briar.api.protocol.RawBatch; +import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportUpdate; -import net.sf.briar.api.protocol.writers.AckWriter; -import net.sf.briar.api.protocol.writers.BatchWriter; -import net.sf.briar.api.protocol.writers.OfferWriter; -import net.sf.briar.api.protocol.writers.RequestWriter; -import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter; -import net.sf.briar.api.protocol.writers.TransportUpdateWriter; import net.sf.briar.api.transport.ConnectionWindow; import org.jmock.Expectations; @@ -105,7 +102,7 @@ public abstract class DatabaseComponentTest extends TestCase { protected abstract <T> DatabaseComponent createDatabaseComponent( Database<T> database, DatabaseCleaner cleaner, - ShutdownManager shutdown); + ShutdownManager shutdown, PacketFactory packetFactory); @Test @SuppressWarnings("unchecked") @@ -115,6 +112,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final ConnectionWindow connectionWindow = context.mock(ConnectionWindow.class); final Group group = context.mock(Group.class); @@ -200,7 +198,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).close(); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.open(false); db.addListener(listener); @@ -233,6 +231,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); context.checking(new Expectations() {{ // setRating(authorId, Rating.GOOD) allowing(database).startTransaction(); @@ -251,7 +250,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).commitTransaction(txn); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.setRating(authorId, Rating.GOOD); @@ -265,6 +264,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); context.checking(new Expectations() {{ // setRating(authorId, Rating.GOOD) oneOf(database).startTransaction(); @@ -287,7 +287,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).commitTransaction(txn); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.setRating(authorId, Rating.GOOD); @@ -302,6 +302,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); context.checking(new Expectations() {{ // setRating(authorId, Rating.GOOD) oneOf(database).startTransaction(); @@ -327,7 +328,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).commitTransaction(txn); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.setRating(authorId, Rating.GOOD); @@ -342,6 +343,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); context.checking(new Expectations() {{ // addLocalGroupMessage(message) oneOf(database).startTransaction(); @@ -351,7 +353,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).commitTransaction(txn); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.addLocalGroupMessage(message); @@ -365,6 +367,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); context.checking(new Expectations() {{ // addLocalGroupMessage(message) oneOf(database).startTransaction(); @@ -376,7 +379,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).commitTransaction(txn); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.addLocalGroupMessage(message); @@ -390,6 +393,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); context.checking(new Expectations() {{ // addLocalGroupMessage(message) oneOf(database).startTransaction(); @@ -410,7 +414,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).commitTransaction(txn); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.addLocalGroupMessage(message); @@ -425,6 +429,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); context.checking(new Expectations() {{ // addLocalGroupMessage(message) oneOf(database).startTransaction(); @@ -448,7 +453,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).commitTransaction(txn); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.addLocalGroupMessage(message); @@ -462,6 +467,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); will(returnValue(txn)); @@ -473,7 +479,7 @@ public abstract class DatabaseComponentTest extends TestCase { will(returnValue(false)); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.addLocalPrivateMessage(privateMessage, contactId); @@ -487,6 +493,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); will(returnValue(txn)); @@ -499,7 +506,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).setStatus(txn, contactId, messageId, Status.NEW); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.addLocalPrivateMessage(privateMessage, contactId); @@ -514,17 +521,10 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); - final AckWriter ackWriter = context.mock(AckWriter.class); - final BatchWriter batchWriter = context.mock(BatchWriter.class); - final OfferWriter offerWriter = context.mock(OfferWriter.class); - final SubscriptionUpdateWriter subscriptionUpdateWriter = - context.mock(SubscriptionUpdateWriter.class); - final TransportUpdateWriter transportUpdateWriter = - context.mock(TransportUpdateWriter.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final Ack ack = context.mock(Ack.class); final Batch batch = context.mock(Batch.class); final Offer offer = context.mock(Offer.class); - final RequestWriter requestWriter = context.mock(RequestWriter.class); final SubscriptionUpdate subscriptionUpdate = context.mock(SubscriptionUpdate.class); final TransportUpdate transportUpdate = @@ -538,7 +538,7 @@ public abstract class DatabaseComponentTest extends TestCase { exactly(19).of(database).commitTransaction(txn); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); try { db.addLocalPrivateMessage(privateMessage, contactId); @@ -546,33 +546,33 @@ public abstract class DatabaseComponentTest extends TestCase { } catch(NoSuchContactException expected) {} try { - db.generateAck(contactId, ackWriter); + db.generateAck(contactId, 123); fail(); } catch(NoSuchContactException expected) {} try { - db.generateBatch(contactId, batchWriter); + db.generateBatch(contactId, 123); fail(); } catch(NoSuchContactException expected) {} try { - db.generateBatch(contactId, batchWriter, + db.generateBatch(contactId, 123, Collections.<MessageId>emptyList()); fail(); } catch(NoSuchContactException expected) {} try { - db.generateOffer(contactId, offerWriter); + db.generateOffer(contactId, 123); fail(); } catch(NoSuchContactException expected) {} try { - db.generateSubscriptionUpdate(contactId, subscriptionUpdateWriter); + db.generateSubscriptionUpdate(contactId); fail(); } catch(NoSuchContactException expected) {} try { - db.generateTransportUpdate(contactId, transportUpdateWriter); + db.generateTransportUpdate(contactId); fail(); } catch(NoSuchContactException expected) {} @@ -607,7 +607,7 @@ public abstract class DatabaseComponentTest extends TestCase { } catch(NoSuchContactException expected) {} try { - db.receiveOffer(contactId, offer, requestWriter); + db.receiveOffer(contactId, offer); fail(); } catch(NoSuchContactException expected) {} @@ -650,7 +650,8 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); - final AckWriter ackWriter = context.mock(AckWriter.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); + final Ack ack = context.mock(Ack.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); will(returnValue(txn)); @@ -658,22 +659,18 @@ public abstract class DatabaseComponentTest extends TestCase { allowing(database).containsContact(txn, contactId); will(returnValue(true)); // Get the batches to ack - oneOf(database).getBatchesToAck(txn, contactId); + oneOf(database).getBatchesToAck(txn, contactId, 123); will(returnValue(batchesToAck)); - // Try to add both batches to the writer - only manage to add one - oneOf(ackWriter).writeBatchId(batchId); - will(returnValue(true)); - oneOf(ackWriter).writeBatchId(batchId1); - will(returnValue(false)); - oneOf(ackWriter).finish(); - // Record the batch that was acked - oneOf(database).removeBatchesToAck(txn, contactId, - Collections.singletonList(batchId)); + // Create the packet + oneOf(packetFactory).createAck(batchesToAck); + will(returnValue(ack)); + // Record the batches that were acked + oneOf(database).removeBatchesToAck(txn, contactId, batchesToAck); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); - db.generateAck(contactId, ackWriter); + assertEquals(ack, db.generateAck(contactId, 123)); context.assertIsSatisfied(); } @@ -682,47 +679,47 @@ public abstract class DatabaseComponentTest extends TestCase { public void testGenerateBatch() throws Exception { final MessageId messageId1 = new MessageId(TestUtils.getRandomId()); final byte[] raw1 = new byte[size]; - final Collection<MessageId> sendable = new ArrayList<MessageId>(); - sendable.add(messageId); - sendable.add(messageId1); + final Collection<MessageId> sendable = Arrays.asList(new MessageId[] { + messageId, + messageId1 + }); + final Collection<byte[]> messages = Arrays.asList(new byte[][] { + raw, + raw1 + }); Mockery context = new Mockery(); @SuppressWarnings("unchecked") final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); - final BatchWriter batchWriter = context.mock(BatchWriter.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); + final RawBatch batch = context.mock(RawBatch.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); will(returnValue(txn)); allowing(database).commitTransaction(txn); allowing(database).containsContact(txn, contactId); will(returnValue(true)); - // Find out how much space we've got - oneOf(batchWriter).getCapacity(); - will(returnValue(ProtocolConstants.MAX_PACKET_LENGTH)); // Get the sendable messages - oneOf(database).getSendableMessages(txn, contactId, - ProtocolConstants.MAX_PACKET_LENGTH); + oneOf(database).getSendableMessages(txn, contactId, size * 2); will(returnValue(sendable)); oneOf(database).getMessage(txn, messageId); will(returnValue(raw)); oneOf(database).getMessage(txn, messageId1); will(returnValue(raw1)); - // Add the sendable messages to the batch - oneOf(batchWriter).writeMessage(raw); - will(returnValue(true)); - oneOf(batchWriter).writeMessage(raw1); - will(returnValue(true)); - oneOf(batchWriter).finish(); + // Create the packet + oneOf(packetFactory).createBatch(messages); + will(returnValue(batch)); + // Record the outstanding batch + oneOf(batch).getId(); will(returnValue(batchId)); - // Record the message that was sent oneOf(database).addOutstandingBatch(txn, contactId, batchId, sendable); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); - db.generateBatch(contactId, batchWriter); + assertEquals(batch, db.generateBatch(contactId, size * 2)); context.assertIsSatisfied(); } @@ -736,21 +733,22 @@ public abstract class DatabaseComponentTest extends TestCase { requested.add(messageId); requested.add(messageId1); requested.add(messageId2); + final Collection<byte[]> msgs = Arrays.asList(new byte[][] { + raw1 + }); Mockery context = new Mockery(); @SuppressWarnings("unchecked") final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); - final BatchWriter batchWriter = context.mock(BatchWriter.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); + final RawBatch batch = context.mock(RawBatch.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); will(returnValue(txn)); allowing(database).commitTransaction(txn); allowing(database).containsContact(txn, contactId); will(returnValue(true)); - // Find out how much space we've got - oneOf(batchWriter).getCapacity(); - will(returnValue(ProtocolConstants.MAX_PACKET_LENGTH)); // Try to get the requested messages oneOf(database).getMessageIfSendable(txn, contactId, messageId); will(returnValue(null)); // Message is not sendable @@ -758,19 +756,19 @@ public abstract class DatabaseComponentTest extends TestCase { will(returnValue(raw1)); // Message is sendable oneOf(database).getMessageIfSendable(txn, contactId, messageId2); will(returnValue(null)); // Message is not sendable - // Add the sendable message to the batch - oneOf(batchWriter).writeMessage(raw1); - will(returnValue(true)); - oneOf(batchWriter).finish(); + // Create the packet + oneOf(packetFactory).createBatch(msgs); + will(returnValue(batch)); + // Record the outstanding batch + oneOf(batch).getId(); will(returnValue(batchId)); - // Record the message that was sent oneOf(database).addOutstandingBatch(txn, contactId, batchId, Collections.singletonList(messageId1)); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); - db.generateBatch(contactId, batchWriter, requested); + assertEquals(batch, db.generateBatch(contactId, size * 3, requested)); context.assertIsSatisfied(); } @@ -778,15 +776,16 @@ public abstract class DatabaseComponentTest extends TestCase { @Test public void testGenerateOffer() throws Exception { final MessageId messageId1 = new MessageId(TestUtils.getRandomId()); - final Collection<MessageId> sendable = new ArrayList<MessageId>(); - sendable.add(messageId); - sendable.add(messageId1); + final Collection<MessageId> offerable = new ArrayList<MessageId>(); + offerable.add(messageId); + offerable.add(messageId1); Mockery context = new Mockery(); @SuppressWarnings("unchecked") final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); - final OfferWriter offerWriter = context.mock(OfferWriter.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); + final Offer offer = context.mock(Offer.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); will(returnValue(txn)); @@ -794,20 +793,16 @@ public abstract class DatabaseComponentTest extends TestCase { allowing(database).containsContact(txn, contactId); will(returnValue(true)); // Get the sendable message IDs - oneOf(database).getSendableMessages(txn, contactId); - will(returnValue(sendable)); - // Try to add both IDs to the writer - only manage to add one - oneOf(offerWriter).writeMessageId(messageId); - will(returnValue(true)); - oneOf(offerWriter).writeMessageId(messageId1); - will(returnValue(false)); - oneOf(offerWriter).finish(); + oneOf(database).getOfferableMessages(txn, contactId, 123); + will(returnValue(offerable)); + // Create the packet + oneOf(packetFactory).createOffer(offerable); + will(returnValue(offer)); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); - assertEquals(Collections.singletonList(messageId), - db.generateOffer(contactId, offerWriter)); + assertEquals(offer, db.generateOffer(contactId, 123)); context.assertIsSatisfied(); } @@ -820,8 +815,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); - final SubscriptionUpdateWriter subscriptionUpdateWriter = - context.mock(SubscriptionUpdateWriter.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); will(returnValue(txn)); @@ -835,26 +829,23 @@ public abstract class DatabaseComponentTest extends TestCase { will(returnValue(now)); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); - db.generateSubscriptionUpdate(contactId, subscriptionUpdateWriter); + assertNull(db.generateSubscriptionUpdate(contactId)); context.assertIsSatisfied(); } @Test public void testGenerateSubscriptionUpdate() throws Exception { - final MessageId messageId1 = new MessageId(TestUtils.getRandomId()); - final Collection<MessageId> sendable = new ArrayList<MessageId>(); - sendable.add(messageId); - sendable.add(messageId1); Mockery context = new Mockery(); @SuppressWarnings("unchecked") final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); - final SubscriptionUpdateWriter subscriptionUpdateWriter = - context.mock(SubscriptionUpdateWriter.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); + final SubscriptionUpdate subscriptionUpdate = + context.mock(SubscriptionUpdate.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); will(returnValue(txn)); @@ -871,15 +862,17 @@ public abstract class DatabaseComponentTest extends TestCase { will(returnValue(Collections.singletonMap(group, 0L))); oneOf(database).setSubscriptionsSent(with(txn), with(contactId), with(any(long.class))); - // Add the subscriptions to the writer - oneOf(subscriptionUpdateWriter).writeSubscriptions( + // Create the packet + oneOf(packetFactory).createSubscriptionUpdate( with(Collections.singletonMap(group, 0L)), with(any(long.class))); + will(returnValue(subscriptionUpdate)); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); - db.generateSubscriptionUpdate(contactId, subscriptionUpdateWriter); + assertEquals(subscriptionUpdate, + db.generateSubscriptionUpdate(contactId)); context.assertIsSatisfied(); } @@ -892,8 +885,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); - final TransportUpdateWriter transportUpdateWriter = - context.mock(TransportUpdateWriter.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); will(returnValue(txn)); @@ -907,26 +899,23 @@ public abstract class DatabaseComponentTest extends TestCase { will(returnValue(now)); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); - db.generateTransportUpdate(contactId, transportUpdateWriter); + assertNull(db.generateTransportUpdate(contactId)); context.assertIsSatisfied(); } @Test public void testGenerateTransportUpdate() throws Exception { - final MessageId messageId1 = new MessageId(TestUtils.getRandomId()); - final Collection<MessageId> sendable = new ArrayList<MessageId>(); - sendable.add(messageId); - sendable.add(messageId1); Mockery context = new Mockery(); @SuppressWarnings("unchecked") final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); - final TransportUpdateWriter transportUpdateWriter = - context.mock(TransportUpdateWriter.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); + final TransportUpdate transportUpdate = + context.mock(TransportUpdate.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); will(returnValue(txn)); @@ -943,14 +932,15 @@ public abstract class DatabaseComponentTest extends TestCase { will(returnValue(transports)); oneOf(database).setTransportsSent(with(txn), with(contactId), with(any(long.class))); - // Add the properties to the writer - oneOf(transportUpdateWriter).writeTransports(with(transports), + // Create the packet + oneOf(packetFactory).createTransportUpdate(with(transports), with(any(long.class))); + will(returnValue(transportUpdate)); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); - db.generateTransportUpdate(contactId, transportUpdateWriter); + assertEquals(transportUpdate, db.generateTransportUpdate(contactId)); context.assertIsSatisfied(); } @@ -963,6 +953,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final Ack ack = context.mock(Ack.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); @@ -980,7 +971,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).removeLostBatch(txn, contactId, batchId1); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.receiveAck(contactId, ack); @@ -994,6 +985,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final Batch batch = context.mock(Batch.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); @@ -1013,7 +1005,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).addBatchToAck(txn, contactId, batchId); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.receiveBatch(contactId, batch); @@ -1027,6 +1019,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final Batch batch = context.mock(Batch.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); @@ -1045,7 +1038,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).addBatchToAck(txn, contactId, batchId); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.receiveBatch(contactId, batch); @@ -1060,6 +1053,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final Batch batch = context.mock(Batch.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); @@ -1079,7 +1073,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).addBatchToAck(txn, contactId, batchId); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.receiveBatch(contactId, batch); @@ -1094,6 +1088,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final Batch batch = context.mock(Batch.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); @@ -1117,7 +1112,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).addBatchToAck(txn, contactId, batchId); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.receiveBatch(contactId, batch); @@ -1131,6 +1126,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final Batch batch = context.mock(Batch.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); @@ -1163,7 +1159,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).addBatchToAck(txn, contactId, batchId); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.receiveBatch(contactId, batch); @@ -1177,6 +1173,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final Batch batch = context.mock(Batch.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); @@ -1211,7 +1208,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).addBatchToAck(txn, contactId, batchId); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.receiveBatch(contactId, batch); @@ -1234,8 +1231,9 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final Offer offer = context.mock(Offer.class); - final RequestWriter requestWriter = context.mock(RequestWriter.class); + final Request request = context.mock(Request.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); will(returnValue(txn)); @@ -1251,12 +1249,14 @@ public abstract class DatabaseComponentTest extends TestCase { will(returnValue(true)); // Visible - do not request message # 1 oneOf(database).setStatusSeenIfVisible(txn, contactId, messageId2); will(returnValue(false)); // Not visible - request message # 2 - oneOf(requestWriter).writeRequest(expectedRequest, 3); + // Create the packet + oneOf(packetFactory).createRequest(expectedRequest, 3); + will(returnValue(request)); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); - db.receiveOffer(contactId, offer, requestWriter); + assertEquals(request, db.receiveOffer(contactId, offer)); context.assertIsSatisfied(); } @@ -1269,6 +1269,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final SubscriptionUpdate subscriptionUpdate = context.mock(SubscriptionUpdate.class); context.checking(new Expectations() {{ @@ -1286,7 +1287,7 @@ public abstract class DatabaseComponentTest extends TestCase { Collections.singletonMap(group, 0L), timestamp); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.receiveSubscriptionUpdate(contactId, subscriptionUpdate); @@ -1301,6 +1302,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final TransportUpdate transportUpdate = context.mock(TransportUpdate.class); context.checking(new Expectations() {{ @@ -1318,7 +1320,7 @@ public abstract class DatabaseComponentTest extends TestCase { timestamp); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.receiveTransportUpdate(contactId, transportUpdate); @@ -1332,6 +1334,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final DatabaseListener listener = context.mock(DatabaseListener.class); context.checking(new Expectations() {{ // addLocalGroupMessage(message) @@ -1354,7 +1357,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(listener).eventOccurred(with(any(MessagesAddedEvent.class))); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.addListener(listener); db.addLocalGroupMessage(message); @@ -1369,6 +1372,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final DatabaseListener listener = context.mock(DatabaseListener.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); @@ -1384,7 +1388,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(listener).eventOccurred(with(any(MessagesAddedEvent.class))); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.addListener(listener); db.addLocalPrivateMessage(privateMessage, contactId); @@ -1400,6 +1404,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final DatabaseListener listener = context.mock(DatabaseListener.class); context.checking(new Expectations() {{ // addLocalGroupMessage(message) @@ -1413,7 +1418,7 @@ public abstract class DatabaseComponentTest extends TestCase { // The message was not added, so the listener should not be called }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.addListener(listener); db.addLocalGroupMessage(message); @@ -1429,6 +1434,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final DatabaseListener listener = context.mock(DatabaseListener.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); @@ -1442,7 +1448,7 @@ public abstract class DatabaseComponentTest extends TestCase { // The message was not added, so the listener should not be called }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.addListener(listener); db.addLocalPrivateMessage(privateMessage, contactId); @@ -1460,6 +1466,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final DatabaseListener listener = context.mock(DatabaseListener.class); context.checking(new Expectations() {{ oneOf(database).startTransaction(); @@ -1474,7 +1481,7 @@ public abstract class DatabaseComponentTest extends TestCase { TransportAddedEvent.class))); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.addListener(listener); db.setLocalProperties(transportId, properties); @@ -1492,6 +1499,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); final DatabaseListener listener = context.mock(DatabaseListener.class); context.checking(new Expectations() {{ oneOf(database).startTransaction(); @@ -1501,7 +1509,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).commitTransaction(txn); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.addListener(listener); db.setLocalProperties(transportId, properties); @@ -1516,6 +1524,7 @@ public abstract class DatabaseComponentTest extends TestCase { final Database<Object> database = context.mock(Database.class); final DatabaseCleaner cleaner = context.mock(DatabaseCleaner.class); final ShutdownManager shutdown = context.mock(ShutdownManager.class); + final PacketFactory packetFactory = context.mock(PacketFactory.class); context.checking(new Expectations() {{ allowing(database).startTransaction(); will(returnValue(txn)); @@ -1526,7 +1535,7 @@ public abstract class DatabaseComponentTest extends TestCase { oneOf(database).setStatusSeenIfVisible(txn, contactId, messageId); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, - shutdown); + shutdown, packetFactory); db.setSeen(contactId, Collections.singletonList(messageId)); diff --git a/test/net/sf/briar/db/H2DatabaseTest.java b/test/net/sf/briar/db/H2DatabaseTest.java index 728346e155cadf68f3e4fe54aef9c8f449cc1cf7..5dbaea4a212dd72ca1e66e771bb5685cd13b672b 100644 --- a/test/net/sf/briar/db/H2DatabaseTest.java +++ b/test/net/sf/briar/db/H2DatabaseTest.java @@ -46,7 +46,6 @@ import net.sf.briar.api.transport.ConnectionWindowFactory; import net.sf.briar.crypto.CryptoModule; import net.sf.briar.lifecycle.LifecycleModule; import net.sf.briar.protocol.ProtocolModule; -import net.sf.briar.protocol.writers.ProtocolWritersModule; import net.sf.briar.serial.SerialModule; import net.sf.briar.transport.TransportModule; import net.sf.briar.transport.batch.TransportBatchModule; @@ -107,10 +106,9 @@ public class H2DatabaseTest extends TestCase { }; Injector i = Guice.createInjector(testModule, new CryptoModule(), new DatabaseModule(), new LifecycleModule(), - new ProtocolModule(), new ProtocolWritersModule(), - new SerialModule(), new TransportBatchModule(), - new TransportModule(), new TransportStreamModule(), - new TestDatabaseModule(testDir)); + new ProtocolModule(), new SerialModule(), + new TransportBatchModule(), new TransportModule(), + new TransportStreamModule(), new TestDatabaseModule(testDir)); connectionContextFactory = i.getInstance(ConnectionContextFactory.class); connectionWindowFactory = i.getInstance(ConnectionWindowFactory.class); @@ -588,7 +586,7 @@ public class H2DatabaseTest extends TestCase { db.addBatchToAck(txn, contactId, batchId1); // Both batch IDs should be returned - Collection<BatchId> acks = db.getBatchesToAck(txn, contactId); + Collection<BatchId> acks = db.getBatchesToAck(txn, contactId, 1234); assertEquals(2, acks.size()); assertTrue(acks.contains(batchId)); assertTrue(acks.contains(batchId1)); @@ -597,7 +595,7 @@ public class H2DatabaseTest extends TestCase { db.removeBatchesToAck(txn, contactId, acks); // Both batch IDs should have been removed - acks = db.getBatchesToAck(txn, contactId); + acks = db.getBatchesToAck(txn, contactId, 1234); assertEquals(0, acks.size()); db.commitTransaction(txn); @@ -615,7 +613,7 @@ public class H2DatabaseTest extends TestCase { db.addBatchToAck(txn, contactId, batchId); // The batch ID should only be returned once - Collection<BatchId> acks = db.getBatchesToAck(txn, contactId); + Collection<BatchId> acks = db.getBatchesToAck(txn, contactId, 1234); assertEquals(1, acks.size()); assertTrue(acks.contains(batchId)); @@ -623,7 +621,7 @@ public class H2DatabaseTest extends TestCase { db.removeBatchesToAck(txn, contactId, acks); // The batch ID should have been removed - acks = db.getBatchesToAck(txn, contactId); + acks = db.getBatchesToAck(txn, contactId, 1234); assertEquals(0, acks.size()); db.commitTransaction(txn); diff --git a/test/net/sf/briar/db/TestMessage.java b/test/net/sf/briar/db/TestMessage.java index 42f54e06fdb4d5f86af110926b13676c05bfc236..4dbe7584a5b6296beafa3ccc47d8352ed6e869ea 100644 --- a/test/net/sf/briar/db/TestMessage.java +++ b/test/net/sf/briar/db/TestMessage.java @@ -61,10 +61,6 @@ class TestMessage implements Message { return timestamp; } - public int getLength() { - return raw.length; - } - public byte[] getSerialised() { return raw; } diff --git a/test/net/sf/briar/protocol/AckReaderTest.java b/test/net/sf/briar/protocol/AckReaderTest.java index 3d22dd7c21ec0596d7cd08668f47b7c04ed234f0..54017dff93367cf9720cd9b7975bceb64a2a205e 100644 --- a/test/net/sf/briar/protocol/AckReaderTest.java +++ b/test/net/sf/briar/protocol/AckReaderTest.java @@ -10,6 +10,7 @@ import junit.framework.TestCase; import net.sf.briar.api.FormatException; import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.BatchId; +import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Types; import net.sf.briar.api.protocol.UniqueId; @@ -42,8 +43,8 @@ public class AckReaderTest extends TestCase { @Test public void testFormatExceptionIfAckIsTooLarge() throws Exception { - AckFactory ackFactory = context.mock(AckFactory.class); - AckReader ackReader = new AckReader(ackFactory); + PacketFactory packetFactory = context.mock(PacketFactory.class); + AckReader ackReader = new AckReader(packetFactory); byte[] b = createAck(true); ByteArrayInputStream in = new ByteArrayInputStream(b); @@ -60,11 +61,11 @@ public class AckReaderTest extends TestCase { @Test @SuppressWarnings("unchecked") public void testNoFormatExceptionIfAckIsMaximumSize() throws Exception { - final AckFactory ackFactory = context.mock(AckFactory.class); - AckReader ackReader = new AckReader(ackFactory); + final PacketFactory packetFactory = context.mock(PacketFactory.class); + AckReader ackReader = new AckReader(packetFactory); final Ack ack = context.mock(Ack.class); context.checking(new Expectations() {{ - oneOf(ackFactory).createAck(with(any(Collection.class))); + oneOf(packetFactory).createAck(with(any(Collection.class))); will(returnValue(ack)); }}); @@ -79,11 +80,11 @@ public class AckReaderTest extends TestCase { @Test public void testEmptyAck() throws Exception { - final AckFactory ackFactory = context.mock(AckFactory.class); - AckReader ackReader = new AckReader(ackFactory); + final PacketFactory packetFactory = context.mock(PacketFactory.class); + AckReader ackReader = new AckReader(packetFactory); final Ack ack = context.mock(Ack.class); context.checking(new Expectations() {{ - oneOf(ackFactory).createAck( + oneOf(packetFactory).createAck( with(Collections.<BatchId>emptyList())); will(returnValue(ack)); }}); diff --git a/test/net/sf/briar/protocol/writers/ConstantsTest.java b/test/net/sf/briar/protocol/ConstantsTest.java similarity index 72% rename from test/net/sf/briar/protocol/writers/ConstantsTest.java rename to test/net/sf/briar/protocol/ConstantsTest.java index c2b82c5edaa6b1dc54d0544df3b70b6774bf8ccb..14b5e3ca565a198081c097e4f7d9c51514f2a24e 100644 --- a/test/net/sf/briar/protocol/writers/ConstantsTest.java +++ b/test/net/sf/briar/protocol/ConstantsTest.java @@ -1,4 +1,4 @@ -package net.sf.briar.protocol.writers; +package net.sf.briar.protocol; import static net.sf.briar.api.protocol.ProtocolConstants.MAX_AUTHOR_NAME_LENGTH; import static net.sf.briar.api.protocol.ProtocolConstants.MAX_BODY_LENGTH; @@ -15,12 +15,14 @@ import java.io.ByteArrayOutputStream; import java.security.PrivateKey; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import junit.framework.TestCase; import net.sf.briar.TestUtils; import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.api.protocol.Ack; import net.sf.briar.api.protocol.Author; import net.sf.briar.api.protocol.AuthorFactory; import net.sf.briar.api.protocol.BatchId; @@ -29,19 +31,18 @@ import net.sf.briar.api.protocol.GroupFactory; import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.MessageFactory; import net.sf.briar.api.protocol.MessageId; +import net.sf.briar.api.protocol.Offer; +import net.sf.briar.api.protocol.PacketFactory; +import net.sf.briar.api.protocol.ProtocolWriter; +import net.sf.briar.api.protocol.ProtocolWriterFactory; +import net.sf.briar.api.protocol.RawBatch; +import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportIndex; +import net.sf.briar.api.protocol.TransportUpdate; import net.sf.briar.api.protocol.UniqueId; -import net.sf.briar.api.protocol.writers.AckWriter; -import net.sf.briar.api.protocol.writers.BatchWriter; -import net.sf.briar.api.protocol.writers.OfferWriter; -import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter; -import net.sf.briar.api.protocol.writers.TransportUpdateWriter; -import net.sf.briar.api.serial.SerialComponent; -import net.sf.briar.api.serial.WriterFactory; import net.sf.briar.crypto.CryptoModule; -import net.sf.briar.protocol.ProtocolModule; import net.sf.briar.serial.SerialModule; import org.junit.Test; @@ -51,23 +52,23 @@ import com.google.inject.Injector; public class ConstantsTest extends TestCase { - private final WriterFactory writerFactory; private final CryptoComponent crypto; - private final SerialComponent serial; private final GroupFactory groupFactory; private final AuthorFactory authorFactory; private final MessageFactory messageFactory; + private final PacketFactory packetFactory; + private final ProtocolWriterFactory protocolWriterFactory; public ConstantsTest() throws Exception { super(); Injector i = Guice.createInjector(new CryptoModule(), new ProtocolModule(), new SerialModule()); - writerFactory = i.getInstance(WriterFactory.class); crypto = i.getInstance(CryptoComponent.class); - serial = i.getInstance(SerialComponent.class); groupFactory = i.getInstance(GroupFactory.class); authorFactory = i.getInstance(AuthorFactory.class); messageFactory = i.getInstance(MessageFactory.class); + packetFactory = i.getInstance(PacketFactory.class); + protocolWriterFactory = i.getInstance(ProtocolWriterFactory.class); } @Test @@ -83,25 +84,18 @@ public class ConstantsTest extends TestCase { private void testBatchesFitIntoAck(int length) throws Exception { // Create an ack with as many batch IDs as possible ByteArrayOutputStream out = new ByteArrayOutputStream(length); - AckWriter a = new AckWriterImpl(out, serial, writerFactory); - a.setMaxPacketLength(length); - while(a.writeBatchId(new BatchId(TestUtils.getRandomId()))); - a.finish(); + ProtocolWriter writer = protocolWriterFactory.createProtocolWriter(out); + int maxBatches = writer.getMaxBatchesForAck(length); + Collection<BatchId> acked = new ArrayList<BatchId>(); + for(int i = 0; i < maxBatches; i++) { + acked.add(new BatchId(TestUtils.getRandomId())); + } + Ack a = packetFactory.createAck(acked); + writer.writeAck(a); // Check the size of the serialised ack assertTrue(out.size() <= length); } - @Test - public void testEmptyAck() throws Exception { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - AckWriter a = new AckWriterImpl(out, serial, writerFactory); - // There's not enough room for a batch ID - a.setMaxPacketLength(4); - assertFalse(a.writeBatchId(new BatchId(TestUtils.getRandomId()))); - // Check that nothing was written - assertEquals(0, out.size()); - } - @Test public void testMessageFitsIntoBatch() throws Exception { // Create a maximum-length group @@ -122,10 +116,10 @@ public class ConstantsTest extends TestCase { // Add the message to a batch ByteArrayOutputStream out = new ByteArrayOutputStream(MAX_PACKET_LENGTH); - BatchWriter b = new BatchWriterImpl(out, serial, writerFactory, - crypto.getMessageDigest()); - assertTrue(b.writeMessage(message.getSerialised())); - b.finish(); + ProtocolWriter writer = protocolWriterFactory.createProtocolWriter(out); + RawBatch b = packetFactory.createBatch(Collections.singletonList( + message.getSerialised())); + writer.writeBatch(b); // Check the size of the serialised batch assertTrue(out.size() > UniqueId.LENGTH + MAX_GROUP_NAME_LENGTH + MAX_PUBLIC_KEY_LENGTH + MAX_AUTHOR_NAME_LENGTH @@ -133,18 +127,6 @@ public class ConstantsTest extends TestCase { assertTrue(out.size() <= MAX_PACKET_LENGTH); } - @Test - public void testEmptyBatch() throws Exception { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - BatchWriter b = new BatchWriterImpl(out, serial, writerFactory, - crypto.getMessageDigest()); - // There's not enough room for a message - b.setMaxPacketLength(4); - assertFalse(b.writeMessage(new byte[4])); - // Check that nothing was written - assertEquals(0, out.size()); - } - @Test public void testMessagesFitIntoLargeOffer() throws Exception { testMessagesFitIntoOffer(MAX_PACKET_LENGTH); @@ -158,25 +140,18 @@ public class ConstantsTest extends TestCase { private void testMessagesFitIntoOffer(int length) throws Exception { // Create an offer with as many message IDs as possible ByteArrayOutputStream out = new ByteArrayOutputStream(length); - OfferWriter o = new OfferWriterImpl(out, serial, writerFactory); - o.setMaxPacketLength(length); - while(o.writeMessageId(new MessageId(TestUtils.getRandomId()))); - o.finish(); + ProtocolWriter writer = protocolWriterFactory.createProtocolWriter(out); + int maxMessages = writer.getMaxMessagesForOffer(length); + Collection<MessageId> offered = new ArrayList<MessageId>(); + for(int i = 0; i < maxMessages; i++) { + offered.add(new MessageId(TestUtils.getRandomId())); + } + Offer o = packetFactory.createOffer(offered); + writer.writeOffer(o); // Check the size of the serialised offer assertTrue(out.size() <= length); } - @Test - public void testEmptyOffer() throws Exception { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - OfferWriter o = new OfferWriterImpl(out, serial, writerFactory); - // There's not enough room for a message ID - o.setMaxPacketLength(4); - assertFalse(o.writeMessageId(new MessageId(TestUtils.getRandomId()))); - // Check that nothing was written - assertEquals(0, out.size()); - } - @Test public void testSubscriptionsFitIntoUpdate() throws Exception { // Create the maximum number of maximum-length subscriptions @@ -190,9 +165,10 @@ public class ConstantsTest extends TestCase { // Add the subscriptions to an update ByteArrayOutputStream out = new ByteArrayOutputStream(MAX_PACKET_LENGTH); - SubscriptionUpdateWriter s = - new SubscriptionUpdateWriterImpl(out, writerFactory); - s.writeSubscriptions(subs, Long.MAX_VALUE); + ProtocolWriter writer = protocolWriterFactory.createProtocolWriter(out); + SubscriptionUpdate s = packetFactory.createSubscriptionUpdate(subs, + Long.MAX_VALUE); + writer.writeSubscriptionUpdate(s); // Check the size of the serialised update assertTrue(out.size() > MAX_GROUPS * (MAX_GROUP_NAME_LENGTH + MAX_PUBLIC_KEY_LENGTH + 8) + 8); @@ -218,9 +194,10 @@ public class ConstantsTest extends TestCase { // Add the transports to an update ByteArrayOutputStream out = new ByteArrayOutputStream(MAX_PACKET_LENGTH); - TransportUpdateWriter t = - new TransportUpdateWriterImpl(out, writerFactory); - t.writeTransports(transports, Long.MAX_VALUE); + ProtocolWriter writer = protocolWriterFactory.createProtocolWriter(out); + TransportUpdate t = packetFactory.createTransportUpdate(transports, + Long.MAX_VALUE); + writer.writeTransportUpdate(t); // Check the size of the serialised update assertTrue(out.size() > MAX_TRANSPORTS * (UniqueId.LENGTH + 4 + (MAX_PROPERTIES_PER_TRANSPORT * MAX_PROPERTY_LENGTH * 2)) diff --git a/test/net/sf/briar/protocol/ProtocolReadWriteTest.java b/test/net/sf/briar/protocol/ProtocolReadWriteTest.java index 121b25e4b6b3c64adcd6c51977bed9da55c67f92..880e45438eaf703bbacc839fc3afdc1bc6eefa25 100644 --- a/test/net/sf/briar/protocol/ProtocolReadWriteTest.java +++ b/test/net/sf/briar/protocol/ProtocolReadWriteTest.java @@ -17,23 +17,19 @@ import net.sf.briar.api.protocol.GroupFactory; import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.MessageFactory; import net.sf.briar.api.protocol.Offer; +import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.ProtocolReader; import net.sf.briar.api.protocol.ProtocolReaderFactory; +import net.sf.briar.api.protocol.ProtocolWriter; +import net.sf.briar.api.protocol.ProtocolWriterFactory; +import net.sf.briar.api.protocol.RawBatch; import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.SubscriptionUpdate; import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportUpdate; -import net.sf.briar.api.protocol.writers.AckWriter; -import net.sf.briar.api.protocol.writers.BatchWriter; -import net.sf.briar.api.protocol.writers.OfferWriter; -import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; -import net.sf.briar.api.protocol.writers.RequestWriter; -import net.sf.briar.api.protocol.writers.SubscriptionUpdateWriter; -import net.sf.briar.api.protocol.writers.TransportUpdateWriter; import net.sf.briar.crypto.CryptoModule; -import net.sf.briar.protocol.writers.ProtocolWritersModule; import net.sf.briar.serial.SerialModule; import org.junit.Test; @@ -45,6 +41,7 @@ public class ProtocolReadWriteTest extends TestCase { private final ProtocolReaderFactory readerFactory; private final ProtocolWriterFactory writerFactory; + private final PacketFactory packetFactory; private final BatchId batchId; private final Group group; private final Message message; @@ -58,10 +55,10 @@ public class ProtocolReadWriteTest extends TestCase { public ProtocolReadWriteTest() throws Exception { super(); Injector i = Guice.createInjector(new CryptoModule(), - new ProtocolModule(), new ProtocolWritersModule(), - new SerialModule()); + new ProtocolModule(), new SerialModule()); readerFactory = i.getInstance(ProtocolReaderFactory.class); writerFactory = i.getInstance(ProtocolWriterFactory.class); + packetFactory = i.getInstance(PacketFactory.class); batchId = new BatchId(TestUtils.getRandomId()); GroupFactory groupFactory = i.getInstance(GroupFactory.class); group = groupFactory.createGroup("Unrestricted group", null); @@ -83,53 +80,54 @@ public class ProtocolReadWriteTest extends TestCase { public void testWriteAndRead() throws Exception { // Write ByteArrayOutputStream out = new ByteArrayOutputStream(); + ProtocolWriter writer = writerFactory.createProtocolWriter(out); - AckWriter a = writerFactory.createAckWriter(out); - a.writeBatchId(batchId); - a.finish(); + Ack a = packetFactory.createAck(Collections.singletonList(batchId)); + writer.writeAck(a); - BatchWriter b = writerFactory.createBatchWriter(out); - b.writeMessage(message.getSerialised()); - b.finish(); + RawBatch b = packetFactory.createBatch(Collections.singletonList( + message.getSerialised())); + writer.writeBatch(b); - OfferWriter o = writerFactory.createOfferWriter(out); - o.writeMessageId(message.getId()); - o.finish(); + Offer o = packetFactory.createOffer(Collections.singletonList( + message.getId())); + writer.writeOffer(o); - RequestWriter r = writerFactory.createRequestWriter(out); - r.writeRequest(bitSet, 10); + Request r = packetFactory.createRequest(bitSet, 10); + writer.writeRequest(r); - SubscriptionUpdateWriter s = - writerFactory.createSubscriptionUpdateWriter(out); - s.writeSubscriptions(subscriptions, timestamp); + SubscriptionUpdate s = packetFactory.createSubscriptionUpdate( + subscriptions, timestamp); + writer.writeSubscriptionUpdate(s); - TransportUpdateWriter t = - writerFactory.createTransportUpdateWriter(out); - t.writeTransports(transports, timestamp); + TransportUpdate t = packetFactory.createTransportUpdate(transports, + timestamp); + writer.writeTransportUpdate(t); // Read ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); ProtocolReader reader = readerFactory.createProtocolReader(in); - Ack ack = reader.readAck(); - assertEquals(Collections.singletonList(batchId), ack.getBatchIds()); + a = reader.readAck(); + assertEquals(Collections.singletonList(batchId), a.getBatchIds()); - Batch batch = reader.readBatch().verify(); - assertEquals(Collections.singletonList(message), batch.getMessages()); + Batch b1 = reader.readBatch().verify(); + assertEquals(Collections.singletonList(message), b1.getMessages()); - Offer offer = reader.readOffer(); + o = reader.readOffer(); assertEquals(Collections.singletonList(message.getId()), - offer.getMessageIds()); + o.getMessageIds()); - Request request = reader.readRequest(); - assertEquals(bitSet, request.getBitmap()); + r = reader.readRequest(); + assertEquals(bitSet, r.getBitmap()); + assertEquals(10, r.getLength()); - SubscriptionUpdate subscriptionUpdate = reader.readSubscriptionUpdate(); - assertEquals(subscriptions, subscriptionUpdate.getSubscriptions()); - assertTrue(subscriptionUpdate.getTimestamp() == timestamp); + s = reader.readSubscriptionUpdate(); + assertEquals(subscriptions, s.getSubscriptions()); + assertEquals(timestamp, s.getTimestamp()); - TransportUpdate transportUpdate = reader.readTransportUpdate(); - assertEquals(transports, transportUpdate.getTransports()); - assertTrue(transportUpdate.getTimestamp() == timestamp); + t = reader.readTransportUpdate(); + assertEquals(transports, t.getTransports()); + assertEquals(timestamp, t.getTimestamp()); } } diff --git a/test/net/sf/briar/protocol/ProtocolWriterImplTest.java b/test/net/sf/briar/protocol/ProtocolWriterImplTest.java new file mode 100644 index 0000000000000000000000000000000000000000..07fe5377baee86a13765b9e37462229afa441307 --- /dev/null +++ b/test/net/sf/briar/protocol/ProtocolWriterImplTest.java @@ -0,0 +1,83 @@ +package net.sf.briar.protocol; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.BitSet; + +import junit.framework.TestCase; +import net.sf.briar.api.protocol.PacketFactory; +import net.sf.briar.api.protocol.ProtocolWriter; +import net.sf.briar.api.protocol.Request; +import net.sf.briar.api.serial.SerialComponent; +import net.sf.briar.api.serial.WriterFactory; +import net.sf.briar.crypto.CryptoModule; +import net.sf.briar.serial.SerialModule; +import net.sf.briar.util.StringUtils; + +import org.junit.Test; + +import com.google.inject.Guice; +import com.google.inject.Injector; + +public class ProtocolWriterImplTest extends TestCase { + + private final PacketFactory packetFactory; + private final SerialComponent serial; + private final WriterFactory writerFactory; + + public ProtocolWriterImplTest() { + super(); + Injector i = Guice.createInjector(new CryptoModule(), + new ProtocolModule(), new SerialModule()); + packetFactory = i.getInstance(PacketFactory.class); + serial = i.getInstance(SerialComponent.class); + writerFactory = i.getInstance(WriterFactory.class); + } + + @Test + public void testWriteBitmapNoPadding() throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + ProtocolWriter w = new ProtocolWriterImpl(serial, writerFactory, out); + BitSet b = new BitSet(); + // 11011001 = 0xD9 + b.set(0); + b.set(1); + b.set(3); + b.set(4); + b.set(7); + // 01011001 = 0x59 + b.set(9); + b.set(11); + b.set(12); + b.set(15); + Request r = packetFactory.createRequest(b, 16); + w.writeRequest(r); + // Short user tag 8, 0 as uint7, short bytes with length 2, 0xD959 + byte[] output = out.toByteArray(); + assertEquals("C8" + "00" + "92" + "D959", + StringUtils.toHexString(output)); + } + + @Test + public void testWriteBitmapWithPadding() throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + ProtocolWriter w = new ProtocolWriterImpl(serial, writerFactory, out); + BitSet b = new BitSet(); + // 01011001 = 0x59 + b.set(1); + b.set(3); + b.set(4); + b.set(7); + // 11011xxx = 0xD8, after padding + b.set(8); + b.set(9); + b.set(11); + b.set(12); + Request r = packetFactory.createRequest(b, 13); + w.writeRequest(r); + // Short user tag 8, 3 as uint7, short bytes with length 2, 0x59D8 + byte[] output = out.toByteArray(); + assertEquals("C8" + "03" + "92" + "59D8", + StringUtils.toHexString(output)); + } +} diff --git a/test/net/sf/briar/protocol/RequestReaderTest.java b/test/net/sf/briar/protocol/RequestReaderTest.java index 1b214e937951d2837f3f37570342a685a3eeee50..b972fba13231e8a70a2e39aeea42ec723d4ca66a 100644 --- a/test/net/sf/briar/protocol/RequestReaderTest.java +++ b/test/net/sf/briar/protocol/RequestReaderTest.java @@ -6,6 +6,7 @@ import java.util.BitSet; import junit.framework.TestCase; import net.sf.briar.api.FormatException; +import net.sf.briar.api.protocol.PacketFactory; import net.sf.briar.api.protocol.ProtocolConstants; import net.sf.briar.api.protocol.Request; import net.sf.briar.api.protocol.Types; @@ -13,6 +14,7 @@ import net.sf.briar.api.serial.Reader; import net.sf.briar.api.serial.ReaderFactory; import net.sf.briar.api.serial.Writer; import net.sf.briar.api.serial.WriterFactory; +import net.sf.briar.crypto.CryptoModule; import net.sf.briar.serial.SerialModule; import org.jmock.Expectations; @@ -26,20 +28,23 @@ public class RequestReaderTest extends TestCase { private final ReaderFactory readerFactory; private final WriterFactory writerFactory; + private final PacketFactory packetFactory; private final Mockery context; public RequestReaderTest() throws Exception { super(); - Injector i = Guice.createInjector(new SerialModule()); + Injector i = Guice.createInjector(new CryptoModule(), + new ProtocolModule(), new SerialModule()); readerFactory = i.getInstance(ReaderFactory.class); writerFactory = i.getInstance(WriterFactory.class); + packetFactory = i.getInstance(PacketFactory.class); context = new Mockery(); } @Test public void testFormatExceptionIfRequestIsTooLarge() throws Exception { - RequestFactory requestFactory = context.mock(RequestFactory.class); - RequestReader requestReader = new RequestReader(requestFactory); + PacketFactory packetFactory = context.mock(PacketFactory.class); + RequestReader requestReader = new RequestReader(packetFactory); byte[] b = createRequest(true); ByteArrayInputStream in = new ByteArrayInputStream(b); @@ -55,12 +60,12 @@ public class RequestReaderTest extends TestCase { @Test public void testNoFormatExceptionIfRequestIsMaximumSize() throws Exception { - final RequestFactory requestFactory = - context.mock(RequestFactory.class); - RequestReader requestReader = new RequestReader(requestFactory); + final PacketFactory packetFactory = context.mock(PacketFactory.class); + RequestReader requestReader = new RequestReader(packetFactory); final Request request = context.mock(Request.class); context.checking(new Expectations() {{ - oneOf(requestFactory).createRequest(with(any(BitSet.class))); + oneOf(packetFactory).createRequest(with(any(BitSet.class)), + with(any(int.class))); will(returnValue(request)); }}); @@ -96,8 +101,7 @@ public class RequestReaderTest extends TestCase { // Deserialise the request ByteArrayInputStream in = new ByteArrayInputStream(b); Reader reader = readerFactory.createReader(in); - RequestReader requestReader = - new RequestReader(new RequestFactoryImpl()); + RequestReader requestReader = new RequestReader(packetFactory); reader.addObjectReader(Types.REQUEST, requestReader); Request r = reader.readStruct(Types.REQUEST, Request.class); BitSet decoded = r.getBitmap(); @@ -116,10 +120,13 @@ public class RequestReaderTest extends TestCase { ByteArrayOutputStream out = new ByteArrayOutputStream(); Writer w = writerFactory.createWriter(out); w.writeStructId(Types.REQUEST); - // Allow one byte for the REQUEST tag, one byte for the BYTES tag, - // and five bytes for the length as an int32 - int size = ProtocolConstants.MAX_PACKET_LENGTH - 7; + // Allow one byte for the REQUEST tag, one byte for the padding length + // as a uint7, one byte for the BYTES tag, and five bytes for the + // length of the byte array as an int32 + int size = ProtocolConstants.MAX_PACKET_LENGTH - 8; if(tooBig) size++; + assertTrue(size > Short.MAX_VALUE); + w.writeUint7((byte) 0); w.writeBytes(new byte[size]); assertEquals(tooBig, out.size() > ProtocolConstants.MAX_PACKET_LENGTH); return out.toByteArray(); @@ -129,6 +136,7 @@ public class RequestReaderTest extends TestCase { ByteArrayOutputStream out = new ByteArrayOutputStream(); Writer w = writerFactory.createWriter(out); w.writeStructId(Types.REQUEST); + w.writeUint7((byte) 0); w.writeBytes(bitmap); return out.toByteArray(); } diff --git a/test/net/sf/briar/protocol/writers/RequestWriterImplTest.java b/test/net/sf/briar/protocol/writers/RequestWriterImplTest.java deleted file mode 100644 index 93ed7180c8b3ab74220536469acb6237046c72f7..0000000000000000000000000000000000000000 --- a/test/net/sf/briar/protocol/writers/RequestWriterImplTest.java +++ /dev/null @@ -1,70 +0,0 @@ -package net.sf.briar.protocol.writers; - -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.util.BitSet; - -import junit.framework.TestCase; -import net.sf.briar.api.protocol.writers.RequestWriter; -import net.sf.briar.api.serial.WriterFactory; -import net.sf.briar.serial.SerialModule; -import net.sf.briar.util.StringUtils; - -import org.junit.Test; - -import com.google.inject.Guice; -import com.google.inject.Injector; - -public class RequestWriterImplTest extends TestCase { - - private final WriterFactory writerFactory; - - public RequestWriterImplTest() { - super(); - Injector i = Guice.createInjector(new SerialModule()); - writerFactory = i.getInstance(WriterFactory.class); - } - - @Test - public void testWriteBitmapNoPadding() throws IOException { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - RequestWriter r = new RequestWriterImpl(out, writerFactory); - BitSet b = new BitSet(); - // 11011001 = 0xD9 - b.set(0); - b.set(1); - b.set(3); - b.set(4); - b.set(7); - // 01011001 = 0x59 - b.set(9); - b.set(11); - b.set(12); - b.set(15); - r.writeRequest(b, 16); - // Short user tag 8, short bytes with length 2, 0xD959 - byte[] output = out.toByteArray(); - assertEquals("C8" + "92" + "D959", StringUtils.toHexString(output)); - } - - @Test - public void testWriteBitmapWithPadding() throws IOException { - ByteArrayOutputStream out = new ByteArrayOutputStream(); - RequestWriter r = new RequestWriterImpl(out, writerFactory); - BitSet b = new BitSet(); - // 01011001 = 0x59 - b.set(1); - b.set(3); - b.set(4); - b.set(7); - // 11011xxx = 0xD8, after padding - b.set(8); - b.set(9); - b.set(11); - b.set(12); - r.writeRequest(b, 13); - // Short user tag 8, short bytes with length 2, 0x59D8 - byte[] output = out.toByteArray(); - assertEquals("C8" + "92" + "59D8", StringUtils.toHexString(output)); - } -} diff --git a/test/net/sf/briar/transport/ConnectionWriterTest.java b/test/net/sf/briar/transport/ConnectionWriterTest.java index 91718a68308297efd79a79f77467eb1b3b7aaa16..d4aa42103a3c5a9ce27eb46836f84693c13beb5d 100644 --- a/test/net/sf/briar/transport/ConnectionWriterTest.java +++ b/test/net/sf/briar/transport/ConnectionWriterTest.java @@ -16,7 +16,6 @@ import net.sf.briar.crypto.CryptoModule; import net.sf.briar.db.DatabaseModule; import net.sf.briar.lifecycle.LifecycleModule; import net.sf.briar.protocol.ProtocolModule; -import net.sf.briar.protocol.writers.ProtocolWritersModule; import net.sf.briar.serial.SerialModule; import net.sf.briar.transport.batch.TransportBatchModule; import net.sf.briar.transport.stream.TransportStreamModule; @@ -44,10 +43,9 @@ public class ConnectionWriterTest extends TestCase { }; Injector i = Guice.createInjector(testModule, new CryptoModule(), new DatabaseModule(), new LifecycleModule(), - new ProtocolModule(), new ProtocolWritersModule(), - new SerialModule(), new TestDatabaseModule(), - new TransportBatchModule(), new TransportModule(), - new TransportStreamModule()); + new ProtocolModule(), new SerialModule(), + new TestDatabaseModule(), new TransportBatchModule(), + new TransportModule(), new TransportStreamModule()); connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class); secret = new byte[32]; new Random().nextBytes(secret); diff --git a/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java b/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java index a25a9aeb5f8ad20b0fbae8c48fce88f14a64e19e..89c03178577ce577170ddce4674ea7ae0b3f97a4 100644 --- a/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java +++ b/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java @@ -26,11 +26,11 @@ import net.sf.briar.api.db.event.MessagesAddedEvent; import net.sf.briar.api.protocol.Message; import net.sf.briar.api.protocol.MessageFactory; import net.sf.briar.api.protocol.ProtocolReaderFactory; +import net.sf.briar.api.protocol.ProtocolWriterFactory; import net.sf.briar.api.protocol.Transport; import net.sf.briar.api.protocol.TransportId; import net.sf.briar.api.protocol.TransportIndex; import net.sf.briar.api.protocol.TransportUpdate; -import net.sf.briar.api.protocol.writers.ProtocolWriterFactory; import net.sf.briar.api.transport.BatchTransportReader; import net.sf.briar.api.transport.BatchTransportWriter; import net.sf.briar.api.transport.ConnectionContext; @@ -43,7 +43,6 @@ import net.sf.briar.db.DatabaseModule; import net.sf.briar.lifecycle.LifecycleModule; import net.sf.briar.plugins.ImmediateExecutor; import net.sf.briar.protocol.ProtocolModule; -import net.sf.briar.protocol.writers.ProtocolWritersModule; import net.sf.briar.serial.SerialModule; import net.sf.briar.transport.TransportModule; import net.sf.briar.transport.stream.TransportStreamModule; @@ -97,10 +96,9 @@ public class BatchConnectionReadWriteTest extends TestCase { }; return Guice.createInjector(testModule, new CryptoModule(), new DatabaseModule(), new LifecycleModule(), - new ProtocolModule(), new ProtocolWritersModule(), - new SerialModule(), new TestDatabaseModule(dir), - new TransportBatchModule(), new TransportModule(), - new TransportStreamModule()); + new ProtocolModule(), new SerialModule(), + new TestDatabaseModule(dir), new TransportBatchModule(), + new TransportModule(), new TransportStreamModule()); } @Test @@ -132,10 +130,10 @@ public class BatchConnectionReadWriteTest extends TestCase { alice.getInstance(ConnectionWriterFactory.class); ProtocolWriterFactory protoFactory = alice.getInstance(ProtocolWriterFactory.class); - BatchTransportWriter writer = new TestBatchTransportWriter(out); - OutgoingBatchConnection batchOut = new OutgoingBatchConnection( - connFactory, db, protoFactory, contactId, transportIndex, - writer); + BatchTransportWriter transport = new TestBatchTransportWriter(out); + OutgoingBatchConnection batchOut = new OutgoingBatchConnection(db, + connFactory, protoFactory, contactId, transportIndex, + transport); // Write whatever needs to be written batchOut.write(); // Close Alice's database @@ -188,7 +186,7 @@ public class BatchConnectionReadWriteTest extends TestCase { bob.getInstance(ProtocolReaderFactory.class); BatchTransportReader reader = new TestBatchTransportReader(in); IncomingBatchConnection batchIn = new IncomingBatchConnection( - new ImmediateExecutor(), connFactory, db, protoFactory, ctx, + new ImmediateExecutor(), db, connFactory, protoFactory, ctx, reader, tag); // No messages should have been added yet assertFalse(listener.messagesAdded);